YOLO for African Wildlife Object Detection¶

In [1]:
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt

random.seed(33)

# Define dataset paths
dataset_path = "data/wildlife/valid"
image_folder = os.path.join(dataset_path, "images")
label_folder = os.path.join(dataset_path, "labels")

# Class labels dictionary
class_labels = {0: "Buffalo", 1: "Elephant", 2: "Rhino", 3: "Zebra"}
colors = {0: (255, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 255, 0)}  # Colors for each class

# Get list of image files
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png'))]

# Select 6 random images
selected_images = random.sample(image_files, min(9, len(image_files)))

def draw_bboxes(image_path, label_path):
    """Draw bounding boxes on an image using YOLO annotations."""
    # Load image
    image = cv2.imread(image_path)
    height, width, _ = image.shape

    # Read YOLO label file
    with open(label_path, "r") as f:
        lines = f.readlines()

    # Draw bounding boxes
    for line in lines:
        class_id, x_center, y_center, bbox_width, bbox_height = map(float, line.strip().split())

        # Convert YOLO format (normalized) to pixel values
        x_center, y_center = int(x_center * width), int(y_center * height)
        bbox_width, bbox_height = int(bbox_width * width), int(bbox_height * height)

        x_min, y_min = x_center - bbox_width // 2, y_center - bbox_height // 2
        x_max, y_max = x_center + bbox_width // 2, y_center + bbox_height // 2

        # Draw rectangle
        color = colors[int(class_id)]
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)

        # Add class label
        label_text = f"{class_labels[int(class_id)]}"
        cv2.putText(image, label_text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB for matplotlib

# Plot 6 images with bounding boxes
fig, axes = plt.subplots(3, 3, figsize=(15, 15))

for ax, image_file in zip(axes.flatten(), selected_images):
    image_path = os.path.join(image_folder, image_file)
    label_path = os.path.join(label_folder, image_file.replace(".jpg", ".txt").replace(".png", ".txt"))

    if os.path.exists(label_path):
        processed_image = draw_bboxes(image_path, label_path)
        ax.imshow(processed_image)
        ax.set_title(image_file)
        ax.axis("off")
    else:
        ax.axis("off")

plt.tight_layout()
plt.show()
No description has been provided for this image
In [2]:
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO

# Load YOLOv11 pre-trained model
model = YOLO("yolo11s.pt") 

# Function to draw predicted bounding boxes
def draw_predictions(image_path):
    """Runs YOLOv11 inference and draws bounding boxes on the image."""
    # Load image
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Run YOLOv11 model on the image
    results = model(image_rgb)

    # Draw bounding boxes
    for result in results:
        boxes = result.boxes.xyxy  # Bounding boxes (x1, y1, x2, y2)
        scores = result.boxes.conf  # Confidence scores
        labels = result.boxes.cls  # Class labels

        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = map(int, box)  # Convert to integers
            label = model.names[int(labels[i])]
            score = scores[i]

            # Draw bounding box
            cv2.rectangle(image_rgb, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(image_rgb, f"{label} {score:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    return image_rgb  # Return processed image

# Plot 6 images with YOLOv11 predictions
fig, axes = plt.subplots(3, 3, figsize=(15, 15))

for ax, image_file in zip(axes.flatten(), selected_images):
    image_path = os.path.join(image_folder, image_file)
    
    # Process image with YOLOv11
    predicted_image = draw_predictions(image_path)
    
    # Display image
    ax.imshow(predicted_image)
    ax.set_title(f"Predictions for {image_file}")
    ax.axis("off")

plt.tight_layout()
plt.show()

/home/kmcalist/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
0: 448x640 1 elephant, 90.7ms
Speed: 3.8ms preprocess, 90.7ms inference, 57.4ms postprocess per image at shape (1, 3, 448, 640)

0: 416x640 1 zebra, 87.5ms
Speed: 0.9ms preprocess, 87.5ms inference, 1.4ms postprocess per image at shape (1, 3, 416, 640)

0: 480x640 1 cow, 86.2ms
Speed: 2.3ms preprocess, 86.2ms inference, 0.7ms postprocess per image at shape (1, 3, 480, 640)

0: 512x640 2 zebras, 85.7ms
Speed: 2.6ms preprocess, 85.7ms inference, 0.9ms postprocess per image at shape (1, 3, 512, 640)

0: 448x640 1 person, 2 cows, 5.7ms
Speed: 1.1ms preprocess, 5.7ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640)

0: 448x640 1 cow, 5.5ms
Speed: 1.7ms preprocess, 5.5ms inference, 0.9ms postprocess per image at shape (1, 3, 448, 640)

0: 448x640 1 zebra, 5.2ms
Speed: 1.3ms preprocess, 5.2ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640)

0: 640x640 4 cows, 2 elephants, 5.5ms
Speed: 0.9ms preprocess, 5.5ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 640)

0: 640x448 1 elephant, 86.3ms
Speed: 0.7ms preprocess, 86.3ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 448)
No description has been provided for this image

Fine Tuning¶

In [5]:
from ultralytics import YOLO

# Load a pre-trained YOLOv11 model
model = YOLO("yolo11s.pt")  # Use "yolov11s.pt" for better accuracy

# Train the model using your dataset
model.train(
    data="wildlife.yaml",  # Path to your dataset config
    epochs=25,                  # Number of training epochs
    imgsz=640,                  # Image size
    batch=32,                    # Batch size (adjust based on GPU memory)
    workers=12,                   # Number of CPU workers
    device="cuda"                # Use GPU if available, otherwise use "cpu"
)
New https://pypi.org/project/ultralytics/8.3.83 available 😃 Update with 'pip install -U ultralytics'
Ultralytics 8.3.82 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3090 Ti, 24245MiB)
engine/trainer: task=detect, mode=train, model=yolo11s.pt, data=wildlife.yaml, epochs=25, time=None, patience=100, batch=32, imgsz=640, save=True, save_period=-1, cache=False, device=cuda, workers=12, project=None, name=train, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=True, opset=None, workspace=None, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, pose=12.0, kobj=1.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, bgr=0.0, mosaic=1.0, mixup=0.0, copy_paste=0.0, copy_paste_mode=flip, auto_augment=randaugment, erasing=0.4, crop_fraction=1.0, cfg=None, tracker=botsort.yaml, save_dir=runs/detect/train
Overriding model.yaml nc=80 with nc=4

                   from  n    params  module                                       arguments                     
  0                  -1  1       928  ultralytics.nn.modules.conv.Conv             [3, 32, 3, 2]                 
  1                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]                
  2                  -1  1     26080  ultralytics.nn.modules.block.C3k2            [64, 128, 1, False, 0.25]     
  3                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
  4                  -1  1    103360  ultralytics.nn.modules.block.C3k2            [128, 256, 1, False, 0.25]    
  5                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
  6                  -1  1    346112  ultralytics.nn.modules.block.C3k2            [256, 256, 1, True]           
  7                  -1  1   1180672  ultralytics.nn.modules.conv.Conv             [256, 512, 3, 2]              
  8                  -1  1   1380352  ultralytics.nn.modules.block.C3k2            [512, 512, 1, True]           
  9                  -1  1    656896  ultralytics.nn.modules.block.SPPF            [512, 512, 5]                 
 10                  -1  1    990976  ultralytics.nn.modules.block.C2PSA           [512, 512, 1]                 
 11                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 12             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 13                  -1  1    443776  ultralytics.nn.modules.block.C3k2            [768, 256, 1, False]          
 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 15             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  1    127680  ultralytics.nn.modules.block.C3k2            [512, 128, 1, False]          
 17                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
 18            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 19                  -1  1    345472  ultralytics.nn.modules.block.C3k2            [384, 256, 1, False]          
 20                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 21            [-1, 10]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  1   1511424  ultralytics.nn.modules.block.C3k2            [768, 512, 1, True]           
 23        [16, 19, 22]  1    820956  ultralytics.nn.modules.head.Detect           [4, [128, 256, 512]]          
YOLO11s summary: 181 layers, 9,429,340 parameters, 9,429,324 gradients, 21.6 GFLOPs

Transferred 493/499 items from pretrained weights
WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. start() got an unexpected keyword argument 'project_name'
WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. 

----NeptuneMissingApiTokenException-------------------------------------------

The Neptune client couldn't find your API token.

You can get it here:
    - https://app.neptune.ai/get_my_api_token

There are two options to add it:
    - specify it in your code
    - set an environment variable in your operating system.

CODE
Pass the token to the init_run() function via the api_token argument:
    neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN')

ENVIRONMENT VARIABLE (Recommended option)
or export or set an environment variable depending on your operating system:

    Linux/Unix
    In your terminal run:
        export NEPTUNE_API_TOKEN="YOUR_API_TOKEN"

    Windows
    In your CMD run:
        set NEPTUNE_API_TOKEN="YOUR_API_TOKEN"

and skip the api_token argument of the init_run() function:
    neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME')

You may also want to check the following docs pages:
    - https://docs.neptune.ai/setup/setting_api_token/

Need help?-> https://docs.neptune.ai/getting_help

TensorBoard: Start with 'tensorboard --logdir runs/detect/train', view at http://localhost:6006/
Freezing layer 'model.23.dfl.conv.weight'
AMP: running Automatic Mixed Precision (AMP) checks...
[neptune] [warning] NeptuneWarning: The following monitoring options are disabled by default in interactive sessions: 'capture_stdout', 'capture_stderr', 'capture_traceback', and 'capture_hardware_metrics'. To enable them, set each parameter to 'True' when initializing the run. The monitoring will continue until you call run.stop() or the kernel stops. Also note: Your source files can only be tracked if you pass the path(s) to the 'source_code' argument. For help, see the Neptune docs: https://docs.neptune.ai/logging/source_code/
AMP: checks passed ✅
train: Scanning /home/kmcalist/QTM447/Spring2025/Lectures/Lecture15/data/wildlife/train/labels.cache... 1052 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1052/1052 [00:00<?, ?it/s]
val: Scanning /home/kmcalist/QTM447/Spring2025/Lectures/Lecture15/data/wildlife/valid/labels.cache... 225 images, 0 backgrounds, 0 corrupt: 100%|██████████| 225/225 [00:00<?, ?it/s]
Plotting labels to runs/detect/train/labels.jpg... 
optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
optimizer: AdamW(lr=0.00125, momentum=0.9) with parameter groups 81 weight(decay=0.0), 88 weight(decay=0.0005), 87 bias(decay=0.0)
TensorBoard: model graph visualization added ✅
Image sizes 640 train, 640 val
Using 12 dataloader workers
Logging results to runs/detect/train
Starting training for 25 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       1/25      8.42G     0.8596      2.283      1.256         98        640: 100%|██████████| 33/33 [00:06<00:00,  5.11it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:01<00:00,  3.56it/s]
                   all        225        379      0.604      0.604      0.691      0.472

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       2/25      8.18G      1.008      1.331      1.328        115        640: 100%|██████████| 33/33 [00:05<00:00,  6.07it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.54it/s]
                   all        225        379      0.608      0.555      0.595      0.375

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       3/25      8.48G      1.023      1.223      1.338        114        640: 100%|██████████| 33/33 [00:05<00:00,  6.28it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  5.38it/s]
                   all        225        379       0.48      0.335       0.32      0.175

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       4/25      8.06G      1.083      1.212      1.383        109        640: 100%|██████████| 33/33 [00:05<00:00,  6.19it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.50it/s]
                   all        225        379      0.655      0.209      0.219      0.126

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       5/25      8.35G      1.075      1.181      1.372        106        640: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.47it/s]
                   all        225        379      0.475      0.452      0.433      0.255

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       6/25       8.2G      1.017      1.091      1.321        115        640: 100%|██████████| 33/33 [00:05<00:00,  6.18it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.62it/s]
                   all        225        379      0.704       0.66      0.735      0.499

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       7/25      8.47G      1.019      1.076      1.333         97        640: 100%|██████████| 33/33 [00:05<00:00,  6.18it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.45it/s]
                   all        225        379      0.815      0.541       0.72      0.516

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       8/25      8.19G     0.9524      1.011      1.281        127        640: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.49it/s]
                   all        225        379      0.849      0.682      0.809      0.573

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
       9/25      8.47G     0.9336     0.9513      1.265        142        640: 100%|██████████| 33/33 [00:05<00:00,  6.11it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.37it/s]
                   all        225        379      0.788       0.66      0.779      0.551

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      10/25      8.19G      0.902     0.8862      1.253        111        640: 100%|██████████| 33/33 [00:05<00:00,  6.19it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.61it/s]
                   all        225        379        0.8      0.743      0.822      0.604

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      11/25      8.47G     0.8912     0.8825      1.244        132        640: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.49it/s]
                   all        225        379      0.831      0.761      0.862      0.624

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      12/25      8.19G     0.8687     0.8668      1.221        110        640: 100%|██████████| 33/33 [00:05<00:00,  6.19it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.62it/s]
                   all        225        379      0.871       0.77      0.873      0.649

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      13/25      8.47G     0.8227     0.7707      1.188        122        640: 100%|██████████| 33/33 [00:05<00:00,  6.12it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.53it/s]
                   all        225        379      0.864      0.827      0.892      0.694

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      14/25       8.1G      0.811     0.7525      1.182        112        640: 100%|██████████| 33/33 [00:05<00:00,  6.06it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.41it/s]
                   all        225        379      0.812       0.81      0.875      0.676

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      15/25      8.47G     0.8019     0.7402      1.179        133        640: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.61it/s]
                   all        225        379       0.94      0.842      0.921      0.728

Closing dataloader mosaic

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      16/25      8.19G     0.7612     0.6495      1.171         40        640: 100%|██████████| 33/33 [00:05<00:00,  5.59it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.59it/s]
                   all        225        379      0.918      0.837      0.919      0.723

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      17/25      8.47G     0.7296      0.631      1.137         42        640: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
                   all        225        379      0.749      0.792      0.823      0.633

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      18/25      8.19G      0.695      0.599      1.109         54        640: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.62it/s]
                   all        225        379      0.902      0.789      0.903      0.724

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      19/25      8.47G     0.6755     0.5599        1.1         53        640: 100%|██████████| 33/33 [00:05<00:00,  6.21it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.47it/s]
                   all        225        379      0.929      0.848      0.925      0.757

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      20/25      8.19G     0.6638     0.5076      1.093         43        640: 100%|██████████| 33/33 [00:05<00:00,  6.20it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.57it/s]
                   all        225        379      0.864      0.863      0.919      0.758

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      21/25      8.47G      0.634     0.4936       1.06         51        640: 100%|██████████| 33/33 [00:05<00:00,  6.28it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.74it/s]
                   all        225        379      0.915      0.851      0.933      0.761

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      22/25      8.19G     0.6195     0.4777      1.062         53        640: 100%|██████████| 33/33 [00:05<00:00,  6.19it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.50it/s]
                   all        225        379      0.897      0.875      0.932      0.776

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      23/25      8.47G     0.6009     0.4635      1.045         43        640: 100%|██████████| 33/33 [00:05<00:00,  6.21it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.73it/s]
                   all        225        379      0.944      0.882       0.95      0.797

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      24/25      8.19G     0.5726      0.428       1.03         46        640: 100%|██████████| 33/33 [00:05<00:00,  6.20it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.33it/s]
                   all        225        379      0.935      0.883      0.952        0.8

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
      25/25      8.47G     0.5576     0.3874      1.003         46        640: 100%|██████████| 33/33 [00:05<00:00,  6.18it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
                   all        225        379      0.938      0.885      0.951      0.807

25 epochs completed in 0.048 hours.
Optimizer stripped from runs/detect/train/weights/last.pt, 19.2MB
Optimizer stripped from runs/detect/train/weights/best.pt, 19.2MB

Validating runs/detect/train/weights/best.pt...
Ultralytics 8.3.82 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3090 Ti, 24245MiB)
YOLO11s summary (fused): 100 layers, 9,414,348 parameters, 0 gradients, 21.3 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:00<00:00,  4.62it/s]
                   all        225        379      0.938      0.885      0.951      0.807
               buffalo         62         89      0.949      0.888      0.955      0.848
              elephant         53         91      0.894      0.868       0.94      0.761
                 rhino         55         85      0.941      0.943      0.966      0.867
                 zebra         59        114       0.97      0.842      0.943      0.751
Speed: 0.1ms preprocess, 1.2ms inference, 0.0ms loss, 0.7ms postprocess per image
Results saved to runs/detect/train
Out[5]:
ultralytics.utils.metrics.DetMetrics object with attributes:

ap_class_index: array([0, 1, 2, 3])
box: ultralytics.utils.metrics.Metric object
confusion_matrix: <ultralytics.utils.metrics.ConfusionMatrix object at 0x7835f50be4d0>
curves: ['Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)']
curves_results: [[array([          0,    0.001001,    0.002002,    0.003003,    0.004004,    0.005005,    0.006006,    0.007007,    0.008008,    0.009009,     0.01001,    0.011011,    0.012012,    0.013013,    0.014014,    0.015015,    0.016016,    0.017017,    0.018018,    0.019019,     0.02002,    0.021021,    0.022022,    0.023023,
          0.024024,    0.025025,    0.026026,    0.027027,    0.028028,    0.029029,     0.03003,    0.031031,    0.032032,    0.033033,    0.034034,    0.035035,    0.036036,    0.037037,    0.038038,    0.039039,     0.04004,    0.041041,    0.042042,    0.043043,    0.044044,    0.045045,    0.046046,    0.047047,
          0.048048,    0.049049,     0.05005,    0.051051,    0.052052,    0.053053,    0.054054,    0.055055,    0.056056,    0.057057,    0.058058,    0.059059,     0.06006,    0.061061,    0.062062,    0.063063,    0.064064,    0.065065,    0.066066,    0.067067,    0.068068,    0.069069,     0.07007,    0.071071,
          0.072072,    0.073073,    0.074074,    0.075075,    0.076076,    0.077077,    0.078078,    0.079079,     0.08008,    0.081081,    0.082082,    0.083083,    0.084084,    0.085085,    0.086086,    0.087087,    0.088088,    0.089089,     0.09009,    0.091091,    0.092092,    0.093093,    0.094094,    0.095095,
          0.096096,    0.097097,    0.098098,    0.099099,      0.1001,      0.1011,      0.1021,      0.1031,      0.1041,     0.10511,     0.10611,     0.10711,     0.10811,     0.10911,     0.11011,     0.11111,     0.11211,     0.11311,     0.11411,     0.11512,     0.11612,     0.11712,     0.11812,     0.11912,
           0.12012,     0.12112,     0.12212,     0.12312,     0.12412,     0.12513,     0.12613,     0.12713,     0.12813,     0.12913,     0.13013,     0.13113,     0.13213,     0.13313,     0.13413,     0.13514,     0.13614,     0.13714,     0.13814,     0.13914,     0.14014,     0.14114,     0.14214,     0.14314,
           0.14414,     0.14515,     0.14615,     0.14715,     0.14815,     0.14915,     0.15015,     0.15115,     0.15215,     0.15315,     0.15415,     0.15516,     0.15616,     0.15716,     0.15816,     0.15916,     0.16016,     0.16116,     0.16216,     0.16316,     0.16416,     0.16517,     0.16617,     0.16717,
           0.16817,     0.16917,     0.17017,     0.17117,     0.17217,     0.17317,     0.17417,     0.17518,     0.17618,     0.17718,     0.17818,     0.17918,     0.18018,     0.18118,     0.18218,     0.18318,     0.18418,     0.18519,     0.18619,     0.18719,     0.18819,     0.18919,     0.19019,     0.19119,
           0.19219,     0.19319,     0.19419,      0.1952,      0.1962,      0.1972,      0.1982,      0.1992,      0.2002,      0.2012,      0.2022,      0.2032,      0.2042,     0.20521,     0.20621,     0.20721,     0.20821,     0.20921,     0.21021,     0.21121,     0.21221,     0.21321,     0.21421,     0.21522,
           0.21622,     0.21722,     0.21822,     0.21922,     0.22022,     0.22122,     0.22222,     0.22322,     0.22422,     0.22523,     0.22623,     0.22723,     0.22823,     0.22923,     0.23023,     0.23123,     0.23223,     0.23323,     0.23423,     0.23524,     0.23624,     0.23724,     0.23824,     0.23924,
           0.24024,     0.24124,     0.24224,     0.24324,     0.24424,     0.24525,     0.24625,     0.24725,     0.24825,     0.24925,     0.25025,     0.25125,     0.25225,     0.25325,     0.25425,     0.25526,     0.25626,     0.25726,     0.25826,     0.25926,     0.26026,     0.26126,     0.26226,     0.26326,
           0.26426,     0.26527,     0.26627,     0.26727,     0.26827,     0.26927,     0.27027,     0.27127,     0.27227,     0.27327,     0.27427,     0.27528,     0.27628,     0.27728,     0.27828,     0.27928,     0.28028,     0.28128,     0.28228,     0.28328,     0.28428,     0.28529,     0.28629,     0.28729,
           0.28829,     0.28929,     0.29029,     0.29129,     0.29229,     0.29329,     0.29429,      0.2953,      0.2963,      0.2973,      0.2983,      0.2993,      0.3003,      0.3013,      0.3023,      0.3033,      0.3043,     0.30531,     0.30631,     0.30731,     0.30831,     0.30931,     0.31031,     0.31131,
           0.31231,     0.31331,     0.31431,     0.31532,     0.31632,     0.31732,     0.31832,     0.31932,     0.32032,     0.32132,     0.32232,     0.32332,     0.32432,     0.32533,     0.32633,     0.32733,     0.32833,     0.32933,     0.33033,     0.33133,     0.33233,     0.33333,     0.33433,     0.33534,
           0.33634,     0.33734,     0.33834,     0.33934,     0.34034,     0.34134,     0.34234,     0.34334,     0.34434,     0.34535,     0.34635,     0.34735,     0.34835,     0.34935,     0.35035,     0.35135,     0.35235,     0.35335,     0.35435,     0.35536,     0.35636,     0.35736,     0.35836,     0.35936,
           0.36036,     0.36136,     0.36236,     0.36336,     0.36436,     0.36537,     0.36637,     0.36737,     0.36837,     0.36937,     0.37037,     0.37137,     0.37237,     0.37337,     0.37437,     0.37538,     0.37638,     0.37738,     0.37838,     0.37938,     0.38038,     0.38138,     0.38238,     0.38338,
           0.38438,     0.38539,     0.38639,     0.38739,     0.38839,     0.38939,     0.39039,     0.39139,     0.39239,     0.39339,     0.39439,      0.3954,      0.3964,      0.3974,      0.3984,      0.3994,      0.4004,      0.4014,      0.4024,      0.4034,      0.4044,     0.40541,     0.40641,     0.40741,
           0.40841,     0.40941,     0.41041,     0.41141,     0.41241,     0.41341,     0.41441,     0.41542,     0.41642,     0.41742,     0.41842,     0.41942,     0.42042,     0.42142,     0.42242,     0.42342,     0.42442,     0.42543,     0.42643,     0.42743,     0.42843,     0.42943,     0.43043,     0.43143,
           0.43243,     0.43343,     0.43443,     0.43544,     0.43644,     0.43744,     0.43844,     0.43944,     0.44044,     0.44144,     0.44244,     0.44344,     0.44444,     0.44545,     0.44645,     0.44745,     0.44845,     0.44945,     0.45045,     0.45145,     0.45245,     0.45345,     0.45445,     0.45546,
           0.45646,     0.45746,     0.45846,     0.45946,     0.46046,     0.46146,     0.46246,     0.46346,     0.46446,     0.46547,     0.46647,     0.46747,     0.46847,     0.46947,     0.47047,     0.47147,     0.47247,     0.47347,     0.47447,     0.47548,     0.47648,     0.47748,     0.47848,     0.47948,
           0.48048,     0.48148,     0.48248,     0.48348,     0.48448,     0.48549,     0.48649,     0.48749,     0.48849,     0.48949,     0.49049,     0.49149,     0.49249,     0.49349,     0.49449,      0.4955,      0.4965,      0.4975,      0.4985,      0.4995,      0.5005,      0.5015,      0.5025,      0.5035,
            0.5045,     0.50551,     0.50651,     0.50751,     0.50851,     0.50951,     0.51051,     0.51151,     0.51251,     0.51351,     0.51451,     0.51552,     0.51652,     0.51752,     0.51852,     0.51952,     0.52052,     0.52152,     0.52252,     0.52352,     0.52452,     0.52553,     0.52653,     0.52753,
           0.52853,     0.52953,     0.53053,     0.53153,     0.53253,     0.53353,     0.53453,     0.53554,     0.53654,     0.53754,     0.53854,     0.53954,     0.54054,     0.54154,     0.54254,     0.54354,     0.54454,     0.54555,     0.54655,     0.54755,     0.54855,     0.54955,     0.55055,     0.55155,
           0.55255,     0.55355,     0.55455,     0.55556,     0.55656,     0.55756,     0.55856,     0.55956,     0.56056,     0.56156,     0.56256,     0.56356,     0.56456,     0.56557,     0.56657,     0.56757,     0.56857,     0.56957,     0.57057,     0.57157,     0.57257,     0.57357,     0.57457,     0.57558,
           0.57658,     0.57758,     0.57858,     0.57958,     0.58058,     0.58158,     0.58258,     0.58358,     0.58458,     0.58559,     0.58659,     0.58759,     0.58859,     0.58959,     0.59059,     0.59159,     0.59259,     0.59359,     0.59459,      0.5956,      0.5966,      0.5976,      0.5986,      0.5996,
            0.6006,      0.6016,      0.6026,      0.6036,      0.6046,     0.60561,     0.60661,     0.60761,     0.60861,     0.60961,     0.61061,     0.61161,     0.61261,     0.61361,     0.61461,     0.61562,     0.61662,     0.61762,     0.61862,     0.61962,     0.62062,     0.62162,     0.62262,     0.62362,
           0.62462,     0.62563,     0.62663,     0.62763,     0.62863,     0.62963,     0.63063,     0.63163,     0.63263,     0.63363,     0.63463,     0.63564,     0.63664,     0.63764,     0.63864,     0.63964,     0.64064,     0.64164,     0.64264,     0.64364,     0.64464,     0.64565,     0.64665,     0.64765,
           0.64865,     0.64965,     0.65065,     0.65165,     0.65265,     0.65365,     0.65465,     0.65566,     0.65666,     0.65766,     0.65866,     0.65966,     0.66066,     0.66166,     0.66266,     0.66366,     0.66466,     0.66567,     0.66667,     0.66767,     0.66867,     0.66967,     0.67067,     0.67167,
           0.67267,     0.67367,     0.67467,     0.67568,     0.67668,     0.67768,     0.67868,     0.67968,     0.68068,     0.68168,     0.68268,     0.68368,     0.68468,     0.68569,     0.68669,     0.68769,     0.68869,     0.68969,     0.69069,     0.69169,     0.69269,     0.69369,     0.69469,      0.6957,
            0.6967,      0.6977,      0.6987,      0.6997,      0.7007,      0.7017,      0.7027,      0.7037,      0.7047,     0.70571,     0.70671,     0.70771,     0.70871,     0.70971,     0.71071,     0.71171,     0.71271,     0.71371,     0.71471,     0.71572,     0.71672,     0.71772,     0.71872,     0.71972,
           0.72072,     0.72172,     0.72272,     0.72372,     0.72472,     0.72573,     0.72673,     0.72773,     0.72873,     0.72973,     0.73073,     0.73173,     0.73273,     0.73373,     0.73473,     0.73574,     0.73674,     0.73774,     0.73874,     0.73974,     0.74074,     0.74174,     0.74274,     0.74374,
           0.74474,     0.74575,     0.74675,     0.74775,     0.74875,     0.74975,     0.75075,     0.75175,     0.75275,     0.75375,     0.75475,     0.75576,     0.75676,     0.75776,     0.75876,     0.75976,     0.76076,     0.76176,     0.76276,     0.76376,     0.76476,     0.76577,     0.76677,     0.76777,
           0.76877,     0.76977,     0.77077,     0.77177,     0.77277,     0.77377,     0.77477,     0.77578,     0.77678,     0.77778,     0.77878,     0.77978,     0.78078,     0.78178,     0.78278,     0.78378,     0.78478,     0.78579,     0.78679,     0.78779,     0.78879,     0.78979,     0.79079,     0.79179,
           0.79279,     0.79379,     0.79479,      0.7958,      0.7968,      0.7978,      0.7988,      0.7998,      0.8008,      0.8018,      0.8028,      0.8038,      0.8048,     0.80581,     0.80681,     0.80781,     0.80881,     0.80981,     0.81081,     0.81181,     0.81281,     0.81381,     0.81481,     0.81582,
           0.81682,     0.81782,     0.81882,     0.81982,     0.82082,     0.82182,     0.82282,     0.82382,     0.82482,     0.82583,     0.82683,     0.82783,     0.82883,     0.82983,     0.83083,     0.83183,     0.83283,     0.83383,     0.83483,     0.83584,     0.83684,     0.83784,     0.83884,     0.83984,
           0.84084,     0.84184,     0.84284,     0.84384,     0.84484,     0.84585,     0.84685,     0.84785,     0.84885,     0.84985,     0.85085,     0.85185,     0.85285,     0.85385,     0.85485,     0.85586,     0.85686,     0.85786,     0.85886,     0.85986,     0.86086,     0.86186,     0.86286,     0.86386,
           0.86486,     0.86587,     0.86687,     0.86787,     0.86887,     0.86987,     0.87087,     0.87187,     0.87287,     0.87387,     0.87487,     0.87588,     0.87688,     0.87788,     0.87888,     0.87988,     0.88088,     0.88188,     0.88288,     0.88388,     0.88488,     0.88589,     0.88689,     0.88789,
           0.88889,     0.88989,     0.89089,     0.89189,     0.89289,     0.89389,     0.89489,      0.8959,      0.8969,      0.8979,      0.8989,      0.8999,      0.9009,      0.9019,      0.9029,      0.9039,      0.9049,     0.90591,     0.90691,     0.90791,     0.90891,     0.90991,     0.91091,     0.91191,
           0.91291,     0.91391,     0.91491,     0.91592,     0.91692,     0.91792,     0.91892,     0.91992,     0.92092,     0.92192,     0.92292,     0.92392,     0.92492,     0.92593,     0.92693,     0.92793,     0.92893,     0.92993,     0.93093,     0.93193,     0.93293,     0.93393,     0.93493,     0.93594,
           0.93694,     0.93794,     0.93894,     0.93994,     0.94094,     0.94194,     0.94294,     0.94394,     0.94494,     0.94595,     0.94695,     0.94795,     0.94895,     0.94995,     0.95095,     0.95195,     0.95295,     0.95395,     0.95495,     0.95596,     0.95696,     0.95796,     0.95896,     0.95996,
           0.96096,     0.96196,     0.96296,     0.96396,     0.96496,     0.96597,     0.96697,     0.96797,     0.96897,     0.96997,     0.97097,     0.97197,     0.97297,     0.97397,     0.97497,     0.97598,     0.97698,     0.97798,     0.97898,     0.97998,     0.98098,     0.98198,     0.98298,     0.98398,
           0.98498,     0.98599,     0.98699,     0.98799,     0.98899,     0.98999,     0.99099,     0.99199,     0.99299,     0.99399,     0.99499,       0.996,       0.997,       0.998,       0.999,           1]), array([[          1,           1,           1, ...,     0.18737,     0.18737,           0],
       [          1,           1,           1, ...,     0.14725,     0.14725,           0],
       [          1,           1,           1, ...,    0.046713,    0.023357,           0],
       [          1,           1,           1, ...,    0.015325,   0.0076623,           0]]), 'Recall', 'Precision'], [array([          0,    0.001001,    0.002002,    0.003003,    0.004004,    0.005005,    0.006006,    0.007007,    0.008008,    0.009009,     0.01001,    0.011011,    0.012012,    0.013013,    0.014014,    0.015015,    0.016016,    0.017017,    0.018018,    0.019019,     0.02002,    0.021021,    0.022022,    0.023023,
          0.024024,    0.025025,    0.026026,    0.027027,    0.028028,    0.029029,     0.03003,    0.031031,    0.032032,    0.033033,    0.034034,    0.035035,    0.036036,    0.037037,    0.038038,    0.039039,     0.04004,    0.041041,    0.042042,    0.043043,    0.044044,    0.045045,    0.046046,    0.047047,
          0.048048,    0.049049,     0.05005,    0.051051,    0.052052,    0.053053,    0.054054,    0.055055,    0.056056,    0.057057,    0.058058,    0.059059,     0.06006,    0.061061,    0.062062,    0.063063,    0.064064,    0.065065,    0.066066,    0.067067,    0.068068,    0.069069,     0.07007,    0.071071,
          0.072072,    0.073073,    0.074074,    0.075075,    0.076076,    0.077077,    0.078078,    0.079079,     0.08008,    0.081081,    0.082082,    0.083083,    0.084084,    0.085085,    0.086086,    0.087087,    0.088088,    0.089089,     0.09009,    0.091091,    0.092092,    0.093093,    0.094094,    0.095095,
          0.096096,    0.097097,    0.098098,    0.099099,      0.1001,      0.1011,      0.1021,      0.1031,      0.1041,     0.10511,     0.10611,     0.10711,     0.10811,     0.10911,     0.11011,     0.11111,     0.11211,     0.11311,     0.11411,     0.11512,     0.11612,     0.11712,     0.11812,     0.11912,
           0.12012,     0.12112,     0.12212,     0.12312,     0.12412,     0.12513,     0.12613,     0.12713,     0.12813,     0.12913,     0.13013,     0.13113,     0.13213,     0.13313,     0.13413,     0.13514,     0.13614,     0.13714,     0.13814,     0.13914,     0.14014,     0.14114,     0.14214,     0.14314,
           0.14414,     0.14515,     0.14615,     0.14715,     0.14815,     0.14915,     0.15015,     0.15115,     0.15215,     0.15315,     0.15415,     0.15516,     0.15616,     0.15716,     0.15816,     0.15916,     0.16016,     0.16116,     0.16216,     0.16316,     0.16416,     0.16517,     0.16617,     0.16717,
           0.16817,     0.16917,     0.17017,     0.17117,     0.17217,     0.17317,     0.17417,     0.17518,     0.17618,     0.17718,     0.17818,     0.17918,     0.18018,     0.18118,     0.18218,     0.18318,     0.18418,     0.18519,     0.18619,     0.18719,     0.18819,     0.18919,     0.19019,     0.19119,
           0.19219,     0.19319,     0.19419,      0.1952,      0.1962,      0.1972,      0.1982,      0.1992,      0.2002,      0.2012,      0.2022,      0.2032,      0.2042,     0.20521,     0.20621,     0.20721,     0.20821,     0.20921,     0.21021,     0.21121,     0.21221,     0.21321,     0.21421,     0.21522,
           0.21622,     0.21722,     0.21822,     0.21922,     0.22022,     0.22122,     0.22222,     0.22322,     0.22422,     0.22523,     0.22623,     0.22723,     0.22823,     0.22923,     0.23023,     0.23123,     0.23223,     0.23323,     0.23423,     0.23524,     0.23624,     0.23724,     0.23824,     0.23924,
           0.24024,     0.24124,     0.24224,     0.24324,     0.24424,     0.24525,     0.24625,     0.24725,     0.24825,     0.24925,     0.25025,     0.25125,     0.25225,     0.25325,     0.25425,     0.25526,     0.25626,     0.25726,     0.25826,     0.25926,     0.26026,     0.26126,     0.26226,     0.26326,
           0.26426,     0.26527,     0.26627,     0.26727,     0.26827,     0.26927,     0.27027,     0.27127,     0.27227,     0.27327,     0.27427,     0.27528,     0.27628,     0.27728,     0.27828,     0.27928,     0.28028,     0.28128,     0.28228,     0.28328,     0.28428,     0.28529,     0.28629,     0.28729,
           0.28829,     0.28929,     0.29029,     0.29129,     0.29229,     0.29329,     0.29429,      0.2953,      0.2963,      0.2973,      0.2983,      0.2993,      0.3003,      0.3013,      0.3023,      0.3033,      0.3043,     0.30531,     0.30631,     0.30731,     0.30831,     0.30931,     0.31031,     0.31131,
           0.31231,     0.31331,     0.31431,     0.31532,     0.31632,     0.31732,     0.31832,     0.31932,     0.32032,     0.32132,     0.32232,     0.32332,     0.32432,     0.32533,     0.32633,     0.32733,     0.32833,     0.32933,     0.33033,     0.33133,     0.33233,     0.33333,     0.33433,     0.33534,
           0.33634,     0.33734,     0.33834,     0.33934,     0.34034,     0.34134,     0.34234,     0.34334,     0.34434,     0.34535,     0.34635,     0.34735,     0.34835,     0.34935,     0.35035,     0.35135,     0.35235,     0.35335,     0.35435,     0.35536,     0.35636,     0.35736,     0.35836,     0.35936,
           0.36036,     0.36136,     0.36236,     0.36336,     0.36436,     0.36537,     0.36637,     0.36737,     0.36837,     0.36937,     0.37037,     0.37137,     0.37237,     0.37337,     0.37437,     0.37538,     0.37638,     0.37738,     0.37838,     0.37938,     0.38038,     0.38138,     0.38238,     0.38338,
           0.38438,     0.38539,     0.38639,     0.38739,     0.38839,     0.38939,     0.39039,     0.39139,     0.39239,     0.39339,     0.39439,      0.3954,      0.3964,      0.3974,      0.3984,      0.3994,      0.4004,      0.4014,      0.4024,      0.4034,      0.4044,     0.40541,     0.40641,     0.40741,
           0.40841,     0.40941,     0.41041,     0.41141,     0.41241,     0.41341,     0.41441,     0.41542,     0.41642,     0.41742,     0.41842,     0.41942,     0.42042,     0.42142,     0.42242,     0.42342,     0.42442,     0.42543,     0.42643,     0.42743,     0.42843,     0.42943,     0.43043,     0.43143,
           0.43243,     0.43343,     0.43443,     0.43544,     0.43644,     0.43744,     0.43844,     0.43944,     0.44044,     0.44144,     0.44244,     0.44344,     0.44444,     0.44545,     0.44645,     0.44745,     0.44845,     0.44945,     0.45045,     0.45145,     0.45245,     0.45345,     0.45445,     0.45546,
           0.45646,     0.45746,     0.45846,     0.45946,     0.46046,     0.46146,     0.46246,     0.46346,     0.46446,     0.46547,     0.46647,     0.46747,     0.46847,     0.46947,     0.47047,     0.47147,     0.47247,     0.47347,     0.47447,     0.47548,     0.47648,     0.47748,     0.47848,     0.47948,
           0.48048,     0.48148,     0.48248,     0.48348,     0.48448,     0.48549,     0.48649,     0.48749,     0.48849,     0.48949,     0.49049,     0.49149,     0.49249,     0.49349,     0.49449,      0.4955,      0.4965,      0.4975,      0.4985,      0.4995,      0.5005,      0.5015,      0.5025,      0.5035,
            0.5045,     0.50551,     0.50651,     0.50751,     0.50851,     0.50951,     0.51051,     0.51151,     0.51251,     0.51351,     0.51451,     0.51552,     0.51652,     0.51752,     0.51852,     0.51952,     0.52052,     0.52152,     0.52252,     0.52352,     0.52452,     0.52553,     0.52653,     0.52753,
           0.52853,     0.52953,     0.53053,     0.53153,     0.53253,     0.53353,     0.53453,     0.53554,     0.53654,     0.53754,     0.53854,     0.53954,     0.54054,     0.54154,     0.54254,     0.54354,     0.54454,     0.54555,     0.54655,     0.54755,     0.54855,     0.54955,     0.55055,     0.55155,
           0.55255,     0.55355,     0.55455,     0.55556,     0.55656,     0.55756,     0.55856,     0.55956,     0.56056,     0.56156,     0.56256,     0.56356,     0.56456,     0.56557,     0.56657,     0.56757,     0.56857,     0.56957,     0.57057,     0.57157,     0.57257,     0.57357,     0.57457,     0.57558,
           0.57658,     0.57758,     0.57858,     0.57958,     0.58058,     0.58158,     0.58258,     0.58358,     0.58458,     0.58559,     0.58659,     0.58759,     0.58859,     0.58959,     0.59059,     0.59159,     0.59259,     0.59359,     0.59459,      0.5956,      0.5966,      0.5976,      0.5986,      0.5996,
            0.6006,      0.6016,      0.6026,      0.6036,      0.6046,     0.60561,     0.60661,     0.60761,     0.60861,     0.60961,     0.61061,     0.61161,     0.61261,     0.61361,     0.61461,     0.61562,     0.61662,     0.61762,     0.61862,     0.61962,     0.62062,     0.62162,     0.62262,     0.62362,
           0.62462,     0.62563,     0.62663,     0.62763,     0.62863,     0.62963,     0.63063,     0.63163,     0.63263,     0.63363,     0.63463,     0.63564,     0.63664,     0.63764,     0.63864,     0.63964,     0.64064,     0.64164,     0.64264,     0.64364,     0.64464,     0.64565,     0.64665,     0.64765,
           0.64865,     0.64965,     0.65065,     0.65165,     0.65265,     0.65365,     0.65465,     0.65566,     0.65666,     0.65766,     0.65866,     0.65966,     0.66066,     0.66166,     0.66266,     0.66366,     0.66466,     0.66567,     0.66667,     0.66767,     0.66867,     0.66967,     0.67067,     0.67167,
           0.67267,     0.67367,     0.67467,     0.67568,     0.67668,     0.67768,     0.67868,     0.67968,     0.68068,     0.68168,     0.68268,     0.68368,     0.68468,     0.68569,     0.68669,     0.68769,     0.68869,     0.68969,     0.69069,     0.69169,     0.69269,     0.69369,     0.69469,      0.6957,
            0.6967,      0.6977,      0.6987,      0.6997,      0.7007,      0.7017,      0.7027,      0.7037,      0.7047,     0.70571,     0.70671,     0.70771,     0.70871,     0.70971,     0.71071,     0.71171,     0.71271,     0.71371,     0.71471,     0.71572,     0.71672,     0.71772,     0.71872,     0.71972,
           0.72072,     0.72172,     0.72272,     0.72372,     0.72472,     0.72573,     0.72673,     0.72773,     0.72873,     0.72973,     0.73073,     0.73173,     0.73273,     0.73373,     0.73473,     0.73574,     0.73674,     0.73774,     0.73874,     0.73974,     0.74074,     0.74174,     0.74274,     0.74374,
           0.74474,     0.74575,     0.74675,     0.74775,     0.74875,     0.74975,     0.75075,     0.75175,     0.75275,     0.75375,     0.75475,     0.75576,     0.75676,     0.75776,     0.75876,     0.75976,     0.76076,     0.76176,     0.76276,     0.76376,     0.76476,     0.76577,     0.76677,     0.76777,
           0.76877,     0.76977,     0.77077,     0.77177,     0.77277,     0.77377,     0.77477,     0.77578,     0.77678,     0.77778,     0.77878,     0.77978,     0.78078,     0.78178,     0.78278,     0.78378,     0.78478,     0.78579,     0.78679,     0.78779,     0.78879,     0.78979,     0.79079,     0.79179,
           0.79279,     0.79379,     0.79479,      0.7958,      0.7968,      0.7978,      0.7988,      0.7998,      0.8008,      0.8018,      0.8028,      0.8038,      0.8048,     0.80581,     0.80681,     0.80781,     0.80881,     0.80981,     0.81081,     0.81181,     0.81281,     0.81381,     0.81481,     0.81582,
           0.81682,     0.81782,     0.81882,     0.81982,     0.82082,     0.82182,     0.82282,     0.82382,     0.82482,     0.82583,     0.82683,     0.82783,     0.82883,     0.82983,     0.83083,     0.83183,     0.83283,     0.83383,     0.83483,     0.83584,     0.83684,     0.83784,     0.83884,     0.83984,
           0.84084,     0.84184,     0.84284,     0.84384,     0.84484,     0.84585,     0.84685,     0.84785,     0.84885,     0.84985,     0.85085,     0.85185,     0.85285,     0.85385,     0.85485,     0.85586,     0.85686,     0.85786,     0.85886,     0.85986,     0.86086,     0.86186,     0.86286,     0.86386,
           0.86486,     0.86587,     0.86687,     0.86787,     0.86887,     0.86987,     0.87087,     0.87187,     0.87287,     0.87387,     0.87487,     0.87588,     0.87688,     0.87788,     0.87888,     0.87988,     0.88088,     0.88188,     0.88288,     0.88388,     0.88488,     0.88589,     0.88689,     0.88789,
           0.88889,     0.88989,     0.89089,     0.89189,     0.89289,     0.89389,     0.89489,      0.8959,      0.8969,      0.8979,      0.8989,      0.8999,      0.9009,      0.9019,      0.9029,      0.9039,      0.9049,     0.90591,     0.90691,     0.90791,     0.90891,     0.90991,     0.91091,     0.91191,
           0.91291,     0.91391,     0.91491,     0.91592,     0.91692,     0.91792,     0.91892,     0.91992,     0.92092,     0.92192,     0.92292,     0.92392,     0.92492,     0.92593,     0.92693,     0.92793,     0.92893,     0.92993,     0.93093,     0.93193,     0.93293,     0.93393,     0.93493,     0.93594,
           0.93694,     0.93794,     0.93894,     0.93994,     0.94094,     0.94194,     0.94294,     0.94394,     0.94494,     0.94595,     0.94695,     0.94795,     0.94895,     0.94995,     0.95095,     0.95195,     0.95295,     0.95395,     0.95495,     0.95596,     0.95696,     0.95796,     0.95896,     0.95996,
           0.96096,     0.96196,     0.96296,     0.96396,     0.96496,     0.96597,     0.96697,     0.96797,     0.96897,     0.96997,     0.97097,     0.97197,     0.97297,     0.97397,     0.97497,     0.97598,     0.97698,     0.97798,     0.97898,     0.97998,     0.98098,     0.98198,     0.98298,     0.98398,
           0.98498,     0.98599,     0.98699,     0.98799,     0.98899,     0.98999,     0.99099,     0.99199,     0.99299,     0.99399,     0.99499,       0.996,       0.997,       0.998,       0.999,           1]), array([[    0.25284,     0.25284,     0.35988, ...,           0,           0,           0],
       [    0.20776,     0.20776,     0.30045, ...,           0,           0,           0],
       [    0.42967,     0.42967,      0.5577, ...,           0,           0,           0],
       [    0.23629,     0.23629,     0.30808, ...,           0,           0,           0]]), 'Confidence', 'F1'], [array([          0,    0.001001,    0.002002,    0.003003,    0.004004,    0.005005,    0.006006,    0.007007,    0.008008,    0.009009,     0.01001,    0.011011,    0.012012,    0.013013,    0.014014,    0.015015,    0.016016,    0.017017,    0.018018,    0.019019,     0.02002,    0.021021,    0.022022,    0.023023,
          0.024024,    0.025025,    0.026026,    0.027027,    0.028028,    0.029029,     0.03003,    0.031031,    0.032032,    0.033033,    0.034034,    0.035035,    0.036036,    0.037037,    0.038038,    0.039039,     0.04004,    0.041041,    0.042042,    0.043043,    0.044044,    0.045045,    0.046046,    0.047047,
          0.048048,    0.049049,     0.05005,    0.051051,    0.052052,    0.053053,    0.054054,    0.055055,    0.056056,    0.057057,    0.058058,    0.059059,     0.06006,    0.061061,    0.062062,    0.063063,    0.064064,    0.065065,    0.066066,    0.067067,    0.068068,    0.069069,     0.07007,    0.071071,
          0.072072,    0.073073,    0.074074,    0.075075,    0.076076,    0.077077,    0.078078,    0.079079,     0.08008,    0.081081,    0.082082,    0.083083,    0.084084,    0.085085,    0.086086,    0.087087,    0.088088,    0.089089,     0.09009,    0.091091,    0.092092,    0.093093,    0.094094,    0.095095,
          0.096096,    0.097097,    0.098098,    0.099099,      0.1001,      0.1011,      0.1021,      0.1031,      0.1041,     0.10511,     0.10611,     0.10711,     0.10811,     0.10911,     0.11011,     0.11111,     0.11211,     0.11311,     0.11411,     0.11512,     0.11612,     0.11712,     0.11812,     0.11912,
           0.12012,     0.12112,     0.12212,     0.12312,     0.12412,     0.12513,     0.12613,     0.12713,     0.12813,     0.12913,     0.13013,     0.13113,     0.13213,     0.13313,     0.13413,     0.13514,     0.13614,     0.13714,     0.13814,     0.13914,     0.14014,     0.14114,     0.14214,     0.14314,
           0.14414,     0.14515,     0.14615,     0.14715,     0.14815,     0.14915,     0.15015,     0.15115,     0.15215,     0.15315,     0.15415,     0.15516,     0.15616,     0.15716,     0.15816,     0.15916,     0.16016,     0.16116,     0.16216,     0.16316,     0.16416,     0.16517,     0.16617,     0.16717,
           0.16817,     0.16917,     0.17017,     0.17117,     0.17217,     0.17317,     0.17417,     0.17518,     0.17618,     0.17718,     0.17818,     0.17918,     0.18018,     0.18118,     0.18218,     0.18318,     0.18418,     0.18519,     0.18619,     0.18719,     0.18819,     0.18919,     0.19019,     0.19119,
           0.19219,     0.19319,     0.19419,      0.1952,      0.1962,      0.1972,      0.1982,      0.1992,      0.2002,      0.2012,      0.2022,      0.2032,      0.2042,     0.20521,     0.20621,     0.20721,     0.20821,     0.20921,     0.21021,     0.21121,     0.21221,     0.21321,     0.21421,     0.21522,
           0.21622,     0.21722,     0.21822,     0.21922,     0.22022,     0.22122,     0.22222,     0.22322,     0.22422,     0.22523,     0.22623,     0.22723,     0.22823,     0.22923,     0.23023,     0.23123,     0.23223,     0.23323,     0.23423,     0.23524,     0.23624,     0.23724,     0.23824,     0.23924,
           0.24024,     0.24124,     0.24224,     0.24324,     0.24424,     0.24525,     0.24625,     0.24725,     0.24825,     0.24925,     0.25025,     0.25125,     0.25225,     0.25325,     0.25425,     0.25526,     0.25626,     0.25726,     0.25826,     0.25926,     0.26026,     0.26126,     0.26226,     0.26326,
           0.26426,     0.26527,     0.26627,     0.26727,     0.26827,     0.26927,     0.27027,     0.27127,     0.27227,     0.27327,     0.27427,     0.27528,     0.27628,     0.27728,     0.27828,     0.27928,     0.28028,     0.28128,     0.28228,     0.28328,     0.28428,     0.28529,     0.28629,     0.28729,
           0.28829,     0.28929,     0.29029,     0.29129,     0.29229,     0.29329,     0.29429,      0.2953,      0.2963,      0.2973,      0.2983,      0.2993,      0.3003,      0.3013,      0.3023,      0.3033,      0.3043,     0.30531,     0.30631,     0.30731,     0.30831,     0.30931,     0.31031,     0.31131,
           0.31231,     0.31331,     0.31431,     0.31532,     0.31632,     0.31732,     0.31832,     0.31932,     0.32032,     0.32132,     0.32232,     0.32332,     0.32432,     0.32533,     0.32633,     0.32733,     0.32833,     0.32933,     0.33033,     0.33133,     0.33233,     0.33333,     0.33433,     0.33534,
           0.33634,     0.33734,     0.33834,     0.33934,     0.34034,     0.34134,     0.34234,     0.34334,     0.34434,     0.34535,     0.34635,     0.34735,     0.34835,     0.34935,     0.35035,     0.35135,     0.35235,     0.35335,     0.35435,     0.35536,     0.35636,     0.35736,     0.35836,     0.35936,
           0.36036,     0.36136,     0.36236,     0.36336,     0.36436,     0.36537,     0.36637,     0.36737,     0.36837,     0.36937,     0.37037,     0.37137,     0.37237,     0.37337,     0.37437,     0.37538,     0.37638,     0.37738,     0.37838,     0.37938,     0.38038,     0.38138,     0.38238,     0.38338,
           0.38438,     0.38539,     0.38639,     0.38739,     0.38839,     0.38939,     0.39039,     0.39139,     0.39239,     0.39339,     0.39439,      0.3954,      0.3964,      0.3974,      0.3984,      0.3994,      0.4004,      0.4014,      0.4024,      0.4034,      0.4044,     0.40541,     0.40641,     0.40741,
           0.40841,     0.40941,     0.41041,     0.41141,     0.41241,     0.41341,     0.41441,     0.41542,     0.41642,     0.41742,     0.41842,     0.41942,     0.42042,     0.42142,     0.42242,     0.42342,     0.42442,     0.42543,     0.42643,     0.42743,     0.42843,     0.42943,     0.43043,     0.43143,
           0.43243,     0.43343,     0.43443,     0.43544,     0.43644,     0.43744,     0.43844,     0.43944,     0.44044,     0.44144,     0.44244,     0.44344,     0.44444,     0.44545,     0.44645,     0.44745,     0.44845,     0.44945,     0.45045,     0.45145,     0.45245,     0.45345,     0.45445,     0.45546,
           0.45646,     0.45746,     0.45846,     0.45946,     0.46046,     0.46146,     0.46246,     0.46346,     0.46446,     0.46547,     0.46647,     0.46747,     0.46847,     0.46947,     0.47047,     0.47147,     0.47247,     0.47347,     0.47447,     0.47548,     0.47648,     0.47748,     0.47848,     0.47948,
           0.48048,     0.48148,     0.48248,     0.48348,     0.48448,     0.48549,     0.48649,     0.48749,     0.48849,     0.48949,     0.49049,     0.49149,     0.49249,     0.49349,     0.49449,      0.4955,      0.4965,      0.4975,      0.4985,      0.4995,      0.5005,      0.5015,      0.5025,      0.5035,
            0.5045,     0.50551,     0.50651,     0.50751,     0.50851,     0.50951,     0.51051,     0.51151,     0.51251,     0.51351,     0.51451,     0.51552,     0.51652,     0.51752,     0.51852,     0.51952,     0.52052,     0.52152,     0.52252,     0.52352,     0.52452,     0.52553,     0.52653,     0.52753,
           0.52853,     0.52953,     0.53053,     0.53153,     0.53253,     0.53353,     0.53453,     0.53554,     0.53654,     0.53754,     0.53854,     0.53954,     0.54054,     0.54154,     0.54254,     0.54354,     0.54454,     0.54555,     0.54655,     0.54755,     0.54855,     0.54955,     0.55055,     0.55155,
           0.55255,     0.55355,     0.55455,     0.55556,     0.55656,     0.55756,     0.55856,     0.55956,     0.56056,     0.56156,     0.56256,     0.56356,     0.56456,     0.56557,     0.56657,     0.56757,     0.56857,     0.56957,     0.57057,     0.57157,     0.57257,     0.57357,     0.57457,     0.57558,
           0.57658,     0.57758,     0.57858,     0.57958,     0.58058,     0.58158,     0.58258,     0.58358,     0.58458,     0.58559,     0.58659,     0.58759,     0.58859,     0.58959,     0.59059,     0.59159,     0.59259,     0.59359,     0.59459,      0.5956,      0.5966,      0.5976,      0.5986,      0.5996,
            0.6006,      0.6016,      0.6026,      0.6036,      0.6046,     0.60561,     0.60661,     0.60761,     0.60861,     0.60961,     0.61061,     0.61161,     0.61261,     0.61361,     0.61461,     0.61562,     0.61662,     0.61762,     0.61862,     0.61962,     0.62062,     0.62162,     0.62262,     0.62362,
           0.62462,     0.62563,     0.62663,     0.62763,     0.62863,     0.62963,     0.63063,     0.63163,     0.63263,     0.63363,     0.63463,     0.63564,     0.63664,     0.63764,     0.63864,     0.63964,     0.64064,     0.64164,     0.64264,     0.64364,     0.64464,     0.64565,     0.64665,     0.64765,
           0.64865,     0.64965,     0.65065,     0.65165,     0.65265,     0.65365,     0.65465,     0.65566,     0.65666,     0.65766,     0.65866,     0.65966,     0.66066,     0.66166,     0.66266,     0.66366,     0.66466,     0.66567,     0.66667,     0.66767,     0.66867,     0.66967,     0.67067,     0.67167,
           0.67267,     0.67367,     0.67467,     0.67568,     0.67668,     0.67768,     0.67868,     0.67968,     0.68068,     0.68168,     0.68268,     0.68368,     0.68468,     0.68569,     0.68669,     0.68769,     0.68869,     0.68969,     0.69069,     0.69169,     0.69269,     0.69369,     0.69469,      0.6957,
            0.6967,      0.6977,      0.6987,      0.6997,      0.7007,      0.7017,      0.7027,      0.7037,      0.7047,     0.70571,     0.70671,     0.70771,     0.70871,     0.70971,     0.71071,     0.71171,     0.71271,     0.71371,     0.71471,     0.71572,     0.71672,     0.71772,     0.71872,     0.71972,
           0.72072,     0.72172,     0.72272,     0.72372,     0.72472,     0.72573,     0.72673,     0.72773,     0.72873,     0.72973,     0.73073,     0.73173,     0.73273,     0.73373,     0.73473,     0.73574,     0.73674,     0.73774,     0.73874,     0.73974,     0.74074,     0.74174,     0.74274,     0.74374,
           0.74474,     0.74575,     0.74675,     0.74775,     0.74875,     0.74975,     0.75075,     0.75175,     0.75275,     0.75375,     0.75475,     0.75576,     0.75676,     0.75776,     0.75876,     0.75976,     0.76076,     0.76176,     0.76276,     0.76376,     0.76476,     0.76577,     0.76677,     0.76777,
           0.76877,     0.76977,     0.77077,     0.77177,     0.77277,     0.77377,     0.77477,     0.77578,     0.77678,     0.77778,     0.77878,     0.77978,     0.78078,     0.78178,     0.78278,     0.78378,     0.78478,     0.78579,     0.78679,     0.78779,     0.78879,     0.78979,     0.79079,     0.79179,
           0.79279,     0.79379,     0.79479,      0.7958,      0.7968,      0.7978,      0.7988,      0.7998,      0.8008,      0.8018,      0.8028,      0.8038,      0.8048,     0.80581,     0.80681,     0.80781,     0.80881,     0.80981,     0.81081,     0.81181,     0.81281,     0.81381,     0.81481,     0.81582,
           0.81682,     0.81782,     0.81882,     0.81982,     0.82082,     0.82182,     0.82282,     0.82382,     0.82482,     0.82583,     0.82683,     0.82783,     0.82883,     0.82983,     0.83083,     0.83183,     0.83283,     0.83383,     0.83483,     0.83584,     0.83684,     0.83784,     0.83884,     0.83984,
           0.84084,     0.84184,     0.84284,     0.84384,     0.84484,     0.84585,     0.84685,     0.84785,     0.84885,     0.84985,     0.85085,     0.85185,     0.85285,     0.85385,     0.85485,     0.85586,     0.85686,     0.85786,     0.85886,     0.85986,     0.86086,     0.86186,     0.86286,     0.86386,
           0.86486,     0.86587,     0.86687,     0.86787,     0.86887,     0.86987,     0.87087,     0.87187,     0.87287,     0.87387,     0.87487,     0.87588,     0.87688,     0.87788,     0.87888,     0.87988,     0.88088,     0.88188,     0.88288,     0.88388,     0.88488,     0.88589,     0.88689,     0.88789,
           0.88889,     0.88989,     0.89089,     0.89189,     0.89289,     0.89389,     0.89489,      0.8959,      0.8969,      0.8979,      0.8989,      0.8999,      0.9009,      0.9019,      0.9029,      0.9039,      0.9049,     0.90591,     0.90691,     0.90791,     0.90891,     0.90991,     0.91091,     0.91191,
           0.91291,     0.91391,     0.91491,     0.91592,     0.91692,     0.91792,     0.91892,     0.91992,     0.92092,     0.92192,     0.92292,     0.92392,     0.92492,     0.92593,     0.92693,     0.92793,     0.92893,     0.92993,     0.93093,     0.93193,     0.93293,     0.93393,     0.93493,     0.93594,
           0.93694,     0.93794,     0.93894,     0.93994,     0.94094,     0.94194,     0.94294,     0.94394,     0.94494,     0.94595,     0.94695,     0.94795,     0.94895,     0.94995,     0.95095,     0.95195,     0.95295,     0.95395,     0.95495,     0.95596,     0.95696,     0.95796,     0.95896,     0.95996,
           0.96096,     0.96196,     0.96296,     0.96396,     0.96496,     0.96597,     0.96697,     0.96797,     0.96897,     0.96997,     0.97097,     0.97197,     0.97297,     0.97397,     0.97497,     0.97598,     0.97698,     0.97798,     0.97898,     0.97998,     0.98098,     0.98198,     0.98298,     0.98398,
           0.98498,     0.98599,     0.98699,     0.98799,     0.98899,     0.98999,     0.99099,     0.99199,     0.99299,     0.99399,     0.99499,       0.996,       0.997,       0.998,       0.999,           1]), array([[    0.14472,     0.14472,     0.21997, ...,           1,           1,           1],
       [    0.11592,     0.11592,     0.17713, ...,           1,           1,           1],
       [    0.27451,     0.27451,     0.38846, ...,           1,           1,           1],
       [    0.13429,     0.13429,      0.1833, ...,           1,           1,           1]]), 'Confidence', 'Precision'], [array([          0,    0.001001,    0.002002,    0.003003,    0.004004,    0.005005,    0.006006,    0.007007,    0.008008,    0.009009,     0.01001,    0.011011,    0.012012,    0.013013,    0.014014,    0.015015,    0.016016,    0.017017,    0.018018,    0.019019,     0.02002,    0.021021,    0.022022,    0.023023,
          0.024024,    0.025025,    0.026026,    0.027027,    0.028028,    0.029029,     0.03003,    0.031031,    0.032032,    0.033033,    0.034034,    0.035035,    0.036036,    0.037037,    0.038038,    0.039039,     0.04004,    0.041041,    0.042042,    0.043043,    0.044044,    0.045045,    0.046046,    0.047047,
          0.048048,    0.049049,     0.05005,    0.051051,    0.052052,    0.053053,    0.054054,    0.055055,    0.056056,    0.057057,    0.058058,    0.059059,     0.06006,    0.061061,    0.062062,    0.063063,    0.064064,    0.065065,    0.066066,    0.067067,    0.068068,    0.069069,     0.07007,    0.071071,
          0.072072,    0.073073,    0.074074,    0.075075,    0.076076,    0.077077,    0.078078,    0.079079,     0.08008,    0.081081,    0.082082,    0.083083,    0.084084,    0.085085,    0.086086,    0.087087,    0.088088,    0.089089,     0.09009,    0.091091,    0.092092,    0.093093,    0.094094,    0.095095,
          0.096096,    0.097097,    0.098098,    0.099099,      0.1001,      0.1011,      0.1021,      0.1031,      0.1041,     0.10511,     0.10611,     0.10711,     0.10811,     0.10911,     0.11011,     0.11111,     0.11211,     0.11311,     0.11411,     0.11512,     0.11612,     0.11712,     0.11812,     0.11912,
           0.12012,     0.12112,     0.12212,     0.12312,     0.12412,     0.12513,     0.12613,     0.12713,     0.12813,     0.12913,     0.13013,     0.13113,     0.13213,     0.13313,     0.13413,     0.13514,     0.13614,     0.13714,     0.13814,     0.13914,     0.14014,     0.14114,     0.14214,     0.14314,
           0.14414,     0.14515,     0.14615,     0.14715,     0.14815,     0.14915,     0.15015,     0.15115,     0.15215,     0.15315,     0.15415,     0.15516,     0.15616,     0.15716,     0.15816,     0.15916,     0.16016,     0.16116,     0.16216,     0.16316,     0.16416,     0.16517,     0.16617,     0.16717,
           0.16817,     0.16917,     0.17017,     0.17117,     0.17217,     0.17317,     0.17417,     0.17518,     0.17618,     0.17718,     0.17818,     0.17918,     0.18018,     0.18118,     0.18218,     0.18318,     0.18418,     0.18519,     0.18619,     0.18719,     0.18819,     0.18919,     0.19019,     0.19119,
           0.19219,     0.19319,     0.19419,      0.1952,      0.1962,      0.1972,      0.1982,      0.1992,      0.2002,      0.2012,      0.2022,      0.2032,      0.2042,     0.20521,     0.20621,     0.20721,     0.20821,     0.20921,     0.21021,     0.21121,     0.21221,     0.21321,     0.21421,     0.21522,
           0.21622,     0.21722,     0.21822,     0.21922,     0.22022,     0.22122,     0.22222,     0.22322,     0.22422,     0.22523,     0.22623,     0.22723,     0.22823,     0.22923,     0.23023,     0.23123,     0.23223,     0.23323,     0.23423,     0.23524,     0.23624,     0.23724,     0.23824,     0.23924,
           0.24024,     0.24124,     0.24224,     0.24324,     0.24424,     0.24525,     0.24625,     0.24725,     0.24825,     0.24925,     0.25025,     0.25125,     0.25225,     0.25325,     0.25425,     0.25526,     0.25626,     0.25726,     0.25826,     0.25926,     0.26026,     0.26126,     0.26226,     0.26326,
           0.26426,     0.26527,     0.26627,     0.26727,     0.26827,     0.26927,     0.27027,     0.27127,     0.27227,     0.27327,     0.27427,     0.27528,     0.27628,     0.27728,     0.27828,     0.27928,     0.28028,     0.28128,     0.28228,     0.28328,     0.28428,     0.28529,     0.28629,     0.28729,
           0.28829,     0.28929,     0.29029,     0.29129,     0.29229,     0.29329,     0.29429,      0.2953,      0.2963,      0.2973,      0.2983,      0.2993,      0.3003,      0.3013,      0.3023,      0.3033,      0.3043,     0.30531,     0.30631,     0.30731,     0.30831,     0.30931,     0.31031,     0.31131,
           0.31231,     0.31331,     0.31431,     0.31532,     0.31632,     0.31732,     0.31832,     0.31932,     0.32032,     0.32132,     0.32232,     0.32332,     0.32432,     0.32533,     0.32633,     0.32733,     0.32833,     0.32933,     0.33033,     0.33133,     0.33233,     0.33333,     0.33433,     0.33534,
           0.33634,     0.33734,     0.33834,     0.33934,     0.34034,     0.34134,     0.34234,     0.34334,     0.34434,     0.34535,     0.34635,     0.34735,     0.34835,     0.34935,     0.35035,     0.35135,     0.35235,     0.35335,     0.35435,     0.35536,     0.35636,     0.35736,     0.35836,     0.35936,
           0.36036,     0.36136,     0.36236,     0.36336,     0.36436,     0.36537,     0.36637,     0.36737,     0.36837,     0.36937,     0.37037,     0.37137,     0.37237,     0.37337,     0.37437,     0.37538,     0.37638,     0.37738,     0.37838,     0.37938,     0.38038,     0.38138,     0.38238,     0.38338,
           0.38438,     0.38539,     0.38639,     0.38739,     0.38839,     0.38939,     0.39039,     0.39139,     0.39239,     0.39339,     0.39439,      0.3954,      0.3964,      0.3974,      0.3984,      0.3994,      0.4004,      0.4014,      0.4024,      0.4034,      0.4044,     0.40541,     0.40641,     0.40741,
           0.40841,     0.40941,     0.41041,     0.41141,     0.41241,     0.41341,     0.41441,     0.41542,     0.41642,     0.41742,     0.41842,     0.41942,     0.42042,     0.42142,     0.42242,     0.42342,     0.42442,     0.42543,     0.42643,     0.42743,     0.42843,     0.42943,     0.43043,     0.43143,
           0.43243,     0.43343,     0.43443,     0.43544,     0.43644,     0.43744,     0.43844,     0.43944,     0.44044,     0.44144,     0.44244,     0.44344,     0.44444,     0.44545,     0.44645,     0.44745,     0.44845,     0.44945,     0.45045,     0.45145,     0.45245,     0.45345,     0.45445,     0.45546,
           0.45646,     0.45746,     0.45846,     0.45946,     0.46046,     0.46146,     0.46246,     0.46346,     0.46446,     0.46547,     0.46647,     0.46747,     0.46847,     0.46947,     0.47047,     0.47147,     0.47247,     0.47347,     0.47447,     0.47548,     0.47648,     0.47748,     0.47848,     0.47948,
           0.48048,     0.48148,     0.48248,     0.48348,     0.48448,     0.48549,     0.48649,     0.48749,     0.48849,     0.48949,     0.49049,     0.49149,     0.49249,     0.49349,     0.49449,      0.4955,      0.4965,      0.4975,      0.4985,      0.4995,      0.5005,      0.5015,      0.5025,      0.5035,
            0.5045,     0.50551,     0.50651,     0.50751,     0.50851,     0.50951,     0.51051,     0.51151,     0.51251,     0.51351,     0.51451,     0.51552,     0.51652,     0.51752,     0.51852,     0.51952,     0.52052,     0.52152,     0.52252,     0.52352,     0.52452,     0.52553,     0.52653,     0.52753,
           0.52853,     0.52953,     0.53053,     0.53153,     0.53253,     0.53353,     0.53453,     0.53554,     0.53654,     0.53754,     0.53854,     0.53954,     0.54054,     0.54154,     0.54254,     0.54354,     0.54454,     0.54555,     0.54655,     0.54755,     0.54855,     0.54955,     0.55055,     0.55155,
           0.55255,     0.55355,     0.55455,     0.55556,     0.55656,     0.55756,     0.55856,     0.55956,     0.56056,     0.56156,     0.56256,     0.56356,     0.56456,     0.56557,     0.56657,     0.56757,     0.56857,     0.56957,     0.57057,     0.57157,     0.57257,     0.57357,     0.57457,     0.57558,
           0.57658,     0.57758,     0.57858,     0.57958,     0.58058,     0.58158,     0.58258,     0.58358,     0.58458,     0.58559,     0.58659,     0.58759,     0.58859,     0.58959,     0.59059,     0.59159,     0.59259,     0.59359,     0.59459,      0.5956,      0.5966,      0.5976,      0.5986,      0.5996,
            0.6006,      0.6016,      0.6026,      0.6036,      0.6046,     0.60561,     0.60661,     0.60761,     0.60861,     0.60961,     0.61061,     0.61161,     0.61261,     0.61361,     0.61461,     0.61562,     0.61662,     0.61762,     0.61862,     0.61962,     0.62062,     0.62162,     0.62262,     0.62362,
           0.62462,     0.62563,     0.62663,     0.62763,     0.62863,     0.62963,     0.63063,     0.63163,     0.63263,     0.63363,     0.63463,     0.63564,     0.63664,     0.63764,     0.63864,     0.63964,     0.64064,     0.64164,     0.64264,     0.64364,     0.64464,     0.64565,     0.64665,     0.64765,
           0.64865,     0.64965,     0.65065,     0.65165,     0.65265,     0.65365,     0.65465,     0.65566,     0.65666,     0.65766,     0.65866,     0.65966,     0.66066,     0.66166,     0.66266,     0.66366,     0.66466,     0.66567,     0.66667,     0.66767,     0.66867,     0.66967,     0.67067,     0.67167,
           0.67267,     0.67367,     0.67467,     0.67568,     0.67668,     0.67768,     0.67868,     0.67968,     0.68068,     0.68168,     0.68268,     0.68368,     0.68468,     0.68569,     0.68669,     0.68769,     0.68869,     0.68969,     0.69069,     0.69169,     0.69269,     0.69369,     0.69469,      0.6957,
            0.6967,      0.6977,      0.6987,      0.6997,      0.7007,      0.7017,      0.7027,      0.7037,      0.7047,     0.70571,     0.70671,     0.70771,     0.70871,     0.70971,     0.71071,     0.71171,     0.71271,     0.71371,     0.71471,     0.71572,     0.71672,     0.71772,     0.71872,     0.71972,
           0.72072,     0.72172,     0.72272,     0.72372,     0.72472,     0.72573,     0.72673,     0.72773,     0.72873,     0.72973,     0.73073,     0.73173,     0.73273,     0.73373,     0.73473,     0.73574,     0.73674,     0.73774,     0.73874,     0.73974,     0.74074,     0.74174,     0.74274,     0.74374,
           0.74474,     0.74575,     0.74675,     0.74775,     0.74875,     0.74975,     0.75075,     0.75175,     0.75275,     0.75375,     0.75475,     0.75576,     0.75676,     0.75776,     0.75876,     0.75976,     0.76076,     0.76176,     0.76276,     0.76376,     0.76476,     0.76577,     0.76677,     0.76777,
           0.76877,     0.76977,     0.77077,     0.77177,     0.77277,     0.77377,     0.77477,     0.77578,     0.77678,     0.77778,     0.77878,     0.77978,     0.78078,     0.78178,     0.78278,     0.78378,     0.78478,     0.78579,     0.78679,     0.78779,     0.78879,     0.78979,     0.79079,     0.79179,
           0.79279,     0.79379,     0.79479,      0.7958,      0.7968,      0.7978,      0.7988,      0.7998,      0.8008,      0.8018,      0.8028,      0.8038,      0.8048,     0.80581,     0.80681,     0.80781,     0.80881,     0.80981,     0.81081,     0.81181,     0.81281,     0.81381,     0.81481,     0.81582,
           0.81682,     0.81782,     0.81882,     0.81982,     0.82082,     0.82182,     0.82282,     0.82382,     0.82482,     0.82583,     0.82683,     0.82783,     0.82883,     0.82983,     0.83083,     0.83183,     0.83283,     0.83383,     0.83483,     0.83584,     0.83684,     0.83784,     0.83884,     0.83984,
           0.84084,     0.84184,     0.84284,     0.84384,     0.84484,     0.84585,     0.84685,     0.84785,     0.84885,     0.84985,     0.85085,     0.85185,     0.85285,     0.85385,     0.85485,     0.85586,     0.85686,     0.85786,     0.85886,     0.85986,     0.86086,     0.86186,     0.86286,     0.86386,
           0.86486,     0.86587,     0.86687,     0.86787,     0.86887,     0.86987,     0.87087,     0.87187,     0.87287,     0.87387,     0.87487,     0.87588,     0.87688,     0.87788,     0.87888,     0.87988,     0.88088,     0.88188,     0.88288,     0.88388,     0.88488,     0.88589,     0.88689,     0.88789,
           0.88889,     0.88989,     0.89089,     0.89189,     0.89289,     0.89389,     0.89489,      0.8959,      0.8969,      0.8979,      0.8989,      0.8999,      0.9009,      0.9019,      0.9029,      0.9039,      0.9049,     0.90591,     0.90691,     0.90791,     0.90891,     0.90991,     0.91091,     0.91191,
           0.91291,     0.91391,     0.91491,     0.91592,     0.91692,     0.91792,     0.91892,     0.91992,     0.92092,     0.92192,     0.92292,     0.92392,     0.92492,     0.92593,     0.92693,     0.92793,     0.92893,     0.92993,     0.93093,     0.93193,     0.93293,     0.93393,     0.93493,     0.93594,
           0.93694,     0.93794,     0.93894,     0.93994,     0.94094,     0.94194,     0.94294,     0.94394,     0.94494,     0.94595,     0.94695,     0.94795,     0.94895,     0.94995,     0.95095,     0.95195,     0.95295,     0.95395,     0.95495,     0.95596,     0.95696,     0.95796,     0.95896,     0.95996,
           0.96096,     0.96196,     0.96296,     0.96396,     0.96496,     0.96597,     0.96697,     0.96797,     0.96897,     0.96997,     0.97097,     0.97197,     0.97297,     0.97397,     0.97497,     0.97598,     0.97698,     0.97798,     0.97898,     0.97998,     0.98098,     0.98198,     0.98298,     0.98398,
           0.98498,     0.98599,     0.98699,     0.98799,     0.98899,     0.98999,     0.99099,     0.99199,     0.99299,     0.99399,     0.99499,       0.996,       0.997,       0.998,       0.999,           1]), array([[          1,           1,     0.98876, ...,           0,           0,           0],
       [          1,           1,     0.98901, ...,           0,           0,           0],
       [    0.98824,     0.98824,     0.98824, ...,           0,           0,           0],
       [    0.98246,     0.98246,     0.96491, ...,           0,           0,           0]]), 'Confidence', 'Recall']]
fitness: 0.8210046188823464
keys: ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
maps: array([    0.84763,      0.7607,     0.86657,     0.75124])
names: {0: 'buffalo', 1: 'elephant', 2: 'rhino', 3: 'zebra'}
plot: True
results_dict: {'metrics/precision(B)': 0.9382732716357025, 'metrics/recall(B)': 0.8850679729846709, 'metrics/mAP50(B)': 0.9512262768368661, 'metrics/mAP50-95(B)': 0.8065355457762887, 'fitness': 0.8210046188823464}
save_dir: PosixPath('runs/detect/train')
speed: {'preprocess': 0.06935741778256165, 'inference': 1.2025215110689815, 'loss': 0.00013822236926191382, 'postprocess': 0.6716766088114431}
task: 'detect'
In [6]:
better_model = YOLO("runs/detect/train/weights/best.pt")

import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO

# Function to draw predicted bounding boxes
def draw_predictions(image_path):
    """Runs YOLOv11 inference and draws bounding boxes on the image."""
    # Load image
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Run YOLOv11 model on the image
    results = better_model(image_rgb)

    # Draw bounding boxes
    for result in results:
        boxes = result.boxes.xyxy  # Bounding boxes (x1, y1, x2, y2)
        scores = result.boxes.conf  # Confidence scores
        labels = result.boxes.cls  # Class labels

        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = map(int, box)  # Convert to integers
            label = model.names[int(labels[i])]
            score = scores[i]

            # Draw bounding box
            cv2.rectangle(image_rgb, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(image_rgb, f"{label} {score:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    return image_rgb  # Return processed image

# Plot 6 images with YOLOv11 predictions
fig, axes = plt.subplots(3, 3, figsize=(15, 15))

for ax, image_file in zip(axes.flatten(), selected_images):
    image_path = os.path.join(image_folder, image_file)
    
    # Process image with YOLOv11
    predicted_image = draw_predictions(image_path)
    
    # Display image
    ax.imshow(predicted_image)
    ax.set_title(f"Predictions for {image_file}")
    ax.axis("off")

plt.tight_layout()
plt.show()
0: 448x640 2 elephants, 8.4ms
Speed: 0.6ms preprocess, 8.4ms inference, 0.7ms postprocess per image at shape (1, 3, 448, 640)

0: 416x640 1 zebra, 8.7ms
Speed: 0.8ms preprocess, 8.7ms inference, 0.6ms postprocess per image at shape (1, 3, 416, 640)

0: 480x640 1 rhino, 8.5ms
Speed: 1.0ms preprocess, 8.5ms inference, 0.6ms postprocess per image at shape (1, 3, 480, 640)

0: 512x640 2 zebras, 8.5ms
Speed: 0.9ms preprocess, 8.5ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 448x640 1 rhino, 7.5ms
Speed: 0.8ms preprocess, 7.5ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640)

0: 448x640 1 buffalo, 7.1ms
Speed: 0.9ms preprocess, 7.1ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640)

0: 448x640 1 zebra, 7.1ms
Speed: 0.6ms preprocess, 7.1ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640)

0: 640x640 6 rhinos, 8.8ms
Speed: 1.0ms preprocess, 8.8ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 640)

0: 640x448 1 elephant, 8.9ms
Speed: 0.6ms preprocess, 8.9ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 448)
No description has been provided for this image

Semantic Segmentation with a Custom U-Net¶

In [37]:
import pandas as pd
import numpy as np
import os

#Organize oxford pets into folder structures

np.random.seed(42)

# Get list of image files
image_folder = "data/oxford-iiit-pet/images"
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png'))]

# Create dataframe
df = pd.DataFrame(image_files, columns=['filename'])

# Add label column
df['label'] = df['filename'].apply(lambda x: 'cat' if x[0].isupper() else 'dog')

# Add set column
df['set'] = np.random.choice([1, 0], size=len(df), p=[0.5, 0.5])

print(df.head())
                           filename label  set
0        german_shorthaired_184.jpg   dog    1
1                    Birman_120.jpg   cat    0
2             great_pyrenees_20.jpg   dog    0
3                     samoyed_6.jpg   dog    0
4  american_pit_bull_terrier_28.jpg   dog    1
In [39]:
import os

# Define the base path
base_path = "data/oxford_pets"

# Define the folder structure
folders = [
    "train/cats", "train/dogs",
    "valid/cats", "valid/dogs",
    "train_trimaps/cats", "train_trimaps/dogs",
    "valid_trimaps/cats", "valid_trimaps/dogs"
]

# Create the folders
for folder in folders:
    os.makedirs(os.path.join(base_path, folder), exist_ok=True)

print("Folder structure created successfully.")
Folder structure created successfully.
In [40]:
import shutil

# Define the trimap folder
trimap_folder = "data/oxford-iiit-pet/annotations/trimaps"

# Function to copy files to the appropriate folder
def copy_files(row):
    image_file = row['filename']
    label = row['label']
    set_type = 'train' if row['set'] == 1 else 'valid'
    
    # Define source and destination paths
    image_src = os.path.join(image_folder, image_file)
    trimap_src = os.path.join(trimap_folder, image_file.replace('.jpg', '.png'))
    
    if os.path.exists(trimap_src):
        image_dst = os.path.join(base_path, f"{set_type}/{label}s", image_file)
        trimap_dst = os.path.join(base_path, f"{set_type}_trimaps/{label}s", image_file.replace('.jpg', '.png'))
        
        # Copy image and trimap
        shutil.copy(image_src, image_dst)
        shutil.copy(trimap_src, trimap_dst)

# Apply the function to each row in the dataframe
df.apply(copy_files, axis=1)

print("Files copied successfully.")
Files copied successfully.
In [41]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

# Define class mapping
CLASS_MAP = {"cats": 0, "dogs": 1}  # 0 for cats, 1 for dogs

# Define a consistent image size
IMAGE_SIZE = (256, 256)

class OxfordPetsDataset(Dataset):
    def __init__(self, image_dir, trimap_dir, transform=None):
        self.image_dir = image_dir
        self.trimap_dir = trimap_dir
        self.transform = transform

        # List all image files
        self.image_files = sorted(os.listdir(image_dir))

        # Define transformations to ensure consistent size for trimaps (without normalization)
        self.trimap_transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE, interpolation=Image.NEAREST),  # Resize while preserving labels
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Get filename
        image_filename = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_filename)
        trimap_path = os.path.join(self.trimap_dir, image_filename.replace(".jpg", ".png"))

        # Load image and trimap
        image = Image.open(image_path).convert("RGB")
        trimap = Image.open(trimap_path).convert("L")  # Trimap is grayscale

        # Resize trimap to (256, 256)
        trimap = self.trimap_transform(trimap)
        trimap = np.array(trimap, dtype=np.uint8)  # Ensure integer encoding

        # Get class label from folder name (cats or dogs)
        class_name = os.path.basename(os.path.dirname(image_path))
        class_label = CLASS_MAP[class_name]  # 0 for cats, 1 for dogs

        # Modify trimap coding (3-class)
        new_trimap = np.zeros_like(trimap, dtype=np.uint8)  # Initialize new trimap
        new_trimap[trimap == 1] = 1  # Outline (both cats & dogs)
        new_trimap[trimap == 2] = 2  # Object (both cats & dogs)

        # Convert trimap to LongTensor before returning
        new_trimap = torch.tensor(new_trimap, dtype=torch.long)  # Ensure it's stored correctly

        # Apply transformations to image (but NOT trimap)
        if self.transform:
            image = self.transform(image)

        return image, new_trimap, torch.tensor(class_label, dtype=torch.long)
In [42]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms

# Define dataset paths for images and trimaps
train_cats_images = "data/oxford_pets/train/cats"
train_dogs_images = "data/oxford_pets/train/dogs"
train_cats_trimaps = "data/oxford_pets/train_trimaps/cats"
train_dogs_trimaps = "data/oxford_pets/train_trimaps/dogs"

valid_cats_images = "data/oxford_pets/valid/cats"
valid_dogs_images = "data/oxford_pets/valid/dogs"
valid_cats_trimaps = "data/oxford_pets/valid_trimaps/cats"
valid_dogs_trimaps = "data/oxford_pets/valid_trimaps/dogs"
In [43]:
# Define transformations (apply only to images)
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),  # Ensure consistent image size
    transforms.ToTensor(),  # Convert image to PyTorch tensor (normalizes to [0,1])
])

# Create Dataset Instances
train_cats_dataset = OxfordPetsDataset(train_cats_images, train_cats_trimaps, transform=data_transforms)
train_dogs_dataset = OxfordPetsDataset(train_dogs_images, train_dogs_trimaps, transform=data_transforms)

valid_cats_dataset = OxfordPetsDataset(valid_cats_images, valid_cats_trimaps, transform=data_transforms)
valid_dogs_dataset = OxfordPetsDataset(valid_dogs_images, valid_dogs_trimaps, transform=data_transforms)

# Merge datasets
from torch.utils.data import ConcatDataset
train_dataset = ConcatDataset([train_cats_dataset, train_dogs_dataset])
valid_dataset = ConcatDataset([valid_cats_dataset, valid_dogs_dataset])

# Create DataLoaders
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

print(f"Train set: {len(train_loader.dataset)} samples")
print(f"Valid set: {len(valid_loader.dataset)} samples")
Train set: 3723 samples
Valid set: 3667 samples
In [44]:
import matplotlib.pyplot as plt
import random
import numpy as np

# Select a random index from the validation dataset
random_idx = random.randint(0, len(valid_dataset) - 1)

# Get the image, trimap, and class label
image, trimap, class_label = valid_dataset[random_idx]

# Convert image tensor to NumPy format for visualization
img_np = image.permute(1, 2, 0).numpy()  # Convert from (C, H, W) to (H, W, C)

# Convert trimap tensor to NumPy
trimap_np = trimap.cpu().numpy()

np.unique(trimap_np)
Out[44]:
array([0, 1, 2])
In [45]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors

# Define class-to-color mapping for trimap
trimap_colors = {
    0: (0, 0, 0),       # Background - Black
    1: (0, 0, 1),       # Dog Outline - Blue
    2: (0, 1, 0),       # Dog Object - Green
    3: (1, 0, 0),       # Cat Outline - Red
    4: (1, 1, 0),       # Cat Object - Yellow
}

# Create a ListedColormap for visualization
cmap = mcolors.ListedColormap([trimap_colors[i] for i in range(len(trimap_colors))])

# Get a batch of training data
images, trimaps, labels = next(iter(train_loader))

# Show the first 4 images and trimaps
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

for i in range(4):
    # Convert tensor image to NumPy format
    img_np = images[i].permute(1, 2, 0).cpu().numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())  # Normalize for display

    # Convert trimap tensor to NumPy
    trimap_np = trimaps[i].cpu().numpy()

    # Image display
    axes[0, i].imshow(img_np)
    axes[0, i].set_title(f"Class: {'Dog' if labels[i] == 1 else 'Cat'}")
    axes[0, i].axis("off")

    # Trimap display with custom colors
    axes[1, i].imshow(trimap_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
    axes[1, i].set_title("Trimap (Labeled)")
    axes[1, i].axis("off")

# Create legend for the trimap colors
fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
cbar = plt.colorbar(mappable=plt.cm.ScalarMappable(cmap=cmap), cax=cbar_ax, ticks=np.arange(len(trimap_colors)))
cbar.ax.set_yticklabels(["Background", "Dog Outline", "Dog", "Cat Outline", "Cat"])  # Custom labels

plt.tight_layout()
plt.show()
/tmp/ipykernel_380213/3389472999.py:47: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout()
No description has been provided for this image
In [46]:
import torch
import torch.nn as nn
import torchvision.models as models

class UNetMobileNet(nn.Module):
    def __init__(self, num_classes=5, pretrained=True, freeze_encoder=True, use_skip_connections=True):
        super(UNetMobileNet, self).__init__()

        self.use_skip_connections = use_skip_connections  # Toggle skip connections

        # Load MobileNetV2 as encoder
        mobilenet = models.mobilenet_v2(pretrained=pretrained)
        encoder_layers = list(mobilenet.features.children())

        # Encoder blocks
        self.encoder1 = nn.Sequential(*encoder_layers[:2])   # 128 × 128
        self.encoder2 = nn.Sequential(*encoder_layers[2:4])  # 64 × 64
        self.encoder3 = nn.Sequential(*encoder_layers[4:7])  # 32 × 32
        self.encoder4 = nn.Sequential(*encoder_layers[7:14]) # 16 × 16
        self.encoder5 = nn.Sequential(*encoder_layers[14:])  # 8 × 8

        # Freeze encoder if specified
        if freeze_encoder:
            for param in self.encoder1.parameters():
                param.requires_grad = False
            for param in self.encoder2.parameters():
                param.requires_grad = False
            for param in self.encoder3.parameters():
                param.requires_grad = False
            for param in self.encoder4.parameters():
                param.requires_grad = False
            for param in self.encoder5.parameters():
                param.requires_grad = False

        # Bottleneck layer
        self.bottleneck = nn.Conv2d(1280, 512, kernel_size=1)

        # Decoder (Transposed Convolutions for upsampling)
        self.decoder4 = self._upsample(512, 96)  # 16 × 16
        self.decoder3 = self._upsample(96, 32)   # 32 × 32
        self.decoder2 = self._upsample(32, 24)   # 64 × 64
        self.decoder1 = self._upsample(24, 16)   # 128 × 128

        # **Extra upsampling to reach 256 × 256**
        self.final_up = self._upsample(16, 16)   # 256 × 256

        self.final_conv = nn.Conv2d(16, num_classes, kernel_size=1)  # Final segmentation layer

    def _upsample(self, in_channels, out_channels):
        """Helper function to create an upsampling block using transposed convolutions."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder
        e1 = self.encoder1(x)  # 128 × 128
        e2 = self.encoder2(e1) # 64 × 64
        e3 = self.encoder3(e2) # 32 × 32
        e4 = self.encoder4(e3) # 16 × 16
        e5 = self.encoder5(e4) # 8 × 8

        # Bottleneck layer
        b = self.bottleneck(e5)

        # Decoder with optional skip connections
        d4 = self.decoder4(b) + e4 if self.use_skip_connections else self.decoder4(b)  # 16 × 16
        d3 = self.decoder3(d4) + e3 if self.use_skip_connections else self.decoder3(d4)  # 32 × 32
        d2 = self.decoder2(d3) + e2 if self.use_skip_connections else self.decoder2(d3)  # 64 × 64
        d1 = self.decoder1(d2) + e1 if self.use_skip_connections else self.decoder1(d2)  # 128 × 128

        d0 = self.final_up(d1)  # **Extra upsampling to reach 256 × 256**

        # Final segmentation map
        return self.final_conv(d0)  # Shape: (batch_size, num_classes, 256, 256)
In [47]:
import networkx as nx
import matplotlib.pyplot as plt

def visualize_unet_with_correct_skips(use_skip_connections=True):
    """
    Creates a U-Net architecture visualization with input/output sizes and channels at each layer.
    
    Args:
        use_skip_connections (bool): Whether to show skip connections.
    
    Returns:
        fig: Matplotlib figure to be displayed inline in Jupyter Notebook.
    """
    G = nx.DiGraph()

    # Define encoder layers with (Height x Width, Channels)
    encoder_layers = {
        "Input (256x256, 3)": (0, 5),
        "Enc1 (128x128, 16)": (1, 4),
        "Enc2 (64x64, 24)": (2, 3),
        "Enc3 (32x32, 32)": (3, 2),
        "Enc4 (16x16, 96)": (4, 1),
        "Bottleneck (8x8, 512)": (5, 0)
    }

    # Define decoder layers with corresponding output sizes and channels
    decoder_layers = {
        "Dec4 (16x16, 96)": (6, 1),
        "Dec3 (32x32, 32)": (7, 2),
        "Dec2 (64x64, 24)": (8, 3),
        "Dec1 (128x128, 16)": (9, 4),
        "Final Up (256x256, 16)": (10, 5),
        "Output (256x256, 5)": (11, 5.5)
    }

    # Add encoder and decoder nodes
    for layer, pos in {**encoder_layers, **decoder_layers}.items():
        G.add_node(layer, pos=pos)

    # Connect encoder layers sequentially
    encoder_keys = list(encoder_layers.keys())
    for i in range(len(encoder_keys) - 1):
        G.add_edge(encoder_keys[i], encoder_keys[i + 1])

    # Connect bottleneck to decoder
    G.add_edge("Bottleneck (8x8, 512)", "Dec4 (16x16, 96)")

    # Connect decoder layers sequentially
    decoder_keys = list(decoder_layers.keys())
    for i in range(len(decoder_keys) - 2):
        G.add_edge(decoder_keys[i], decoder_keys[i + 1])

    # Connect final output
    G.add_edge("Final Up (256x256, 16)", "Output (256x256, 5)")

    # Add **correct** skip connections if enabled
    if use_skip_connections:
        skip_connections = {
            "Enc1 (128x128, 16)": "Dec1 (128x128, 16)",
            "Enc2 (64x64, 24)": "Dec2 (64x64, 24)",
            "Enc3 (32x32, 32)": "Dec3 (32x32, 32)",
            "Enc4 (16x16, 96)": "Dec4 (16x16, 96)"
        }
        for enc, dec in skip_connections.items():
            G.add_edge(enc, dec, color="red", style="dashed")

    # Extract positions for visualization
    pos = nx.get_node_attributes(G, "pos")

    # Create a Matplotlib figure
    fig, ax = plt.subplots(figsize=(12, 6))
    edges = G.edges()

    # Color edges differently for skip connections
    edge_colors = ["red" if G[u][v].get("color") == "red" else "black" for u, v in edges]
    edge_styles = ["dashed" if G[u][v].get("style") == "dashed" else "solid" for u, v in edges]

    # Draw nodes
    nx.draw(G, pos, with_labels=True, node_color="lightblue", node_size=2500, font_size=8, font_weight="bold", ax=ax)

    # Draw normal and skip edges separately to allow different styles
    for i, (u, v) in enumerate(edges):
        nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=edge_colors[i], style=edge_styles[i], width=2, ax=ax)

    ax.set_title(f"U-Net Architecture {'(With Skip Connections)' if use_skip_connections else '(Without Skip Connections)'}")
    
    return fig

# Display the images inline in Jupyter Notebook
fig1 = visualize_unet_with_correct_skips(use_skip_connections=True)  # With Skip Connections
fig2 = visualize_unet_with_correct_skips(use_skip_connections=False) # Without Skip Connections

# Show figures in Jupyter Notebook
plt.show()
No description has been provided for this image
No description has been provided for this image
In [48]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics

class UNetLightning(pl.LightningModule):
    def __init__(self, num_classes=5, lr=1e-3, freeze_encoder=True, use_skip_connections=True):
        super(UNetLightning, self).__init__()

        # Load U-Net with correct output shape (256x256)
        self.model = UNetMobileNet(
            num_classes=num_classes, 
            freeze_encoder=freeze_encoder, 
            use_skip_connections=use_skip_connections
        )

        # Loss function (Cross-entropy for multi-class segmentation)
        self.criterion = nn.CrossEntropyLoss()

        # IoU (Intersection over Union) metric
        self.iou = torchmetrics.JaccardIndex(task="multiclass", num_classes=num_classes)

        # Learning rate
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """ Training step: Computes loss and IoU for segmentation. """
        images, trimaps, _ = batch  # Extract inputs and ground truth masks
        logits = self.model(images)  # Forward pass
        loss = self.criterion(logits, trimaps.squeeze(1).long())  # Squeeze trimap to shape (batch_size, height, width)

        # Compute IoU (Jaccard Index)
        iou = self.iou(torch.argmax(logits, dim=1), trimaps.squeeze(1).long())

        # Log loss and IoU for training
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_iou", iou, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """ Validation step: Evaluates loss and IoU on the validation set. """
        images, trimaps, _ = batch
        logits = self.model(images)
        loss = self.criterion(logits, trimaps.squeeze(1).long())  # Squeeze trimap to shape (batch_size, height, width)

        # Compute IoU
        iou = self.iou(torch.argmax(logits, dim=1), trimaps.squeeze(1).long())

        # Log metrics for validation
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_iou", iou, prog_bar=True)

        return loss


    def predict_step(self, batch, batch_idx):
        """ Prediction step: Runs inference on new images. """
        images, _, _ = batch  # We only need images for prediction
        logits = self.model(images)  # Forward pass
        preds = torch.argmax(logits, dim=1)  # Get predicted class labels

        return preds  # Return predicted segmentation masks

    def configure_optimizers(self):
        """ Optimizer and Learning Rate Scheduler """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
In [49]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger
# Set up logging and early stopping
csv_logger = CSVLogger(save_dir='logs/', name='UNetNoSkips', version="")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=25, verbose=True, mode="min")

# Create the model instance
model = UNetLightning(use_skip_connections = False)

# Assume train_loader and val_loader are defined DataLoaders
trainer = pl.Trainer(
    max_epochs=10,
    logger=csv_logger,
    callbacks=[early_stop_callback]
)

trainer.fit(model, train_loader, valid_loader)

# Save the final model state
trainer.save_checkpoint('logs/UNetNoSkips/final_model.ckpt')
/home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kmcalist/.local/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:269: Experiment logs directory logs/UNetNoSkips/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory logs/UNetNoSkips/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params | Mode 
-------------------------------------------------------------
0 | model     | UNetMobileNet          | 3.8 M  | train
1 | criterion | CrossEntropyLoss       | 0      | train
2 | iou       | MulticlassJaccardIndex | 0      | train
-------------------------------------------------------------
1.6 M     Trainable params
2.2 M     Non-trainable params
3.8 M     Total params
15.361    Total estimated model params size (MB)
Sanity Checking: |          | 0/? [00:00<?, ?it/s]
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (15) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Training: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved. New best score: 1.304
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.245 >= min_delta = 0.0. New best score: 1.059
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.172 >= min_delta = 0.0. New best score: 0.887
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.138 >= min_delta = 0.0. New best score: 0.749
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.084 >= min_delta = 0.0. New best score: 0.665
Validation: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.132 >= min_delta = 0.0. New best score: 0.533
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.046 >= min_delta = 0.0. New best score: 0.487
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.479
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.033 >= min_delta = 0.0. New best score: 0.446
`Trainer.fit` stopped: `max_epochs=10` reached.
In [50]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger
# Set up logging and early stopping
csv_logger = CSVLogger(save_dir='logs/', name='UNetSkips', version="")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=25, verbose=True, mode="min")

# Create the model instance
model = UNetLightning(use_skip_connections = True)

# Assume train_loader and val_loader are defined DataLoaders
trainer = pl.Trainer(
    max_epochs=5,
    logger=csv_logger,
    callbacks=[early_stop_callback]
)

trainer.fit(model, train_loader, valid_loader)

# Save the final model state
trainer.save_checkpoint('logs/UNetSkips/final_model.ckpt')
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kmcalist/.local/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:269: Experiment logs directory logs/UNetSkips/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory logs/UNetSkips/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params | Mode 
-------------------------------------------------------------
0 | model     | UNetMobileNet          | 3.8 M  | train
1 | criterion | CrossEntropyLoss       | 0      | train
2 | iou       | MulticlassJaccardIndex | 0      | train
-------------------------------------------------------------
1.6 M     Trainable params
2.2 M     Non-trainable params
3.8 M     Total params
15.361    Total estimated model params size (MB)
Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Training: |          | 0/? [00:00<?, ?it/s]
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved. New best score: 1.336
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.289 >= min_delta = 0.0. New best score: 1.047
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.035 >= min_delta = 0.0. New best score: 1.012
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 0.984
Validation: |          | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.964
`Trainer.fit` stopped: `max_epochs=5` reached.
In [52]:
import torch
import pytorch_lightning as pl

# Load trained models
model_with_skips = UNetLightning.load_from_checkpoint(
    'logs/UNetSkips/final_model.ckpt',
    use_skip_connections=True  # Ensure it matches training
)

model_without_skips = UNetLightning.load_from_checkpoint(
    'logs/UNetNoSkips/final_model.ckpt',
    use_skip_connections=False  # Ensure it matches training
)

# Set models to evaluation mode
model_with_skips.eval()
model_without_skips.eval()

print("done")
done
/home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
In [53]:
import random

# Get the validation dataset from the DataLoader
valid_dataset = valid_loader.dataset

# Randomly select 10 unique indices
random_indices = random.sample(range(len(valid_dataset)), 10)

# Extract the corresponding images (without labels)
random_samples = [valid_dataset[i] for i in random_indices]

# Convert images to batch format for model prediction
images = torch.stack([sample[0] for sample in random_samples])  # Image tensors
trimaps = torch.stack([sample[1] for sample in random_samples])  # Ground truth trimaps
classes = torch.stack([sample[2] for sample in random_samples])
In [54]:
# Create DataLoader for prediction
pred_loader = torch.utils.data.DataLoader(list(zip(images, trimaps, classes)), batch_size=10)
In [55]:
# Define Lightning trainer (No training, just prediction mode)
trainer = pl.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu")

# Run predictions with trainer.predict()
preds_with_skips = trainer.predict(model_with_skips, dataloaders=pred_loader)
preds_without_skips = trainer.predict(model_without_skips, dataloaders=pred_loader)

# Convert list of batch tensors to a single tensor
preds_with_skips = torch.cat(preds_with_skips, dim=0)
preds_without_skips = torch.cat(preds_without_skips, dim=0)

# Convert logits to class labels
preds_with_skips = torch.argmax(preds_with_skips, dim=1).cpu().numpy()
preds_without_skips = torch.argmax(preds_without_skips, dim=1).cpu().numpy()
trimaps = trimaps.squeeze(1).cpu().numpy()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2025-03-04 14:16:52.935266: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-04 14:16:52.935290: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-04 14:16:52.936145: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-04 14:16:52.940029: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-04 14:16:53.520495: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
Predicting: |          | 0/? [00:00<?, ?it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting: |          | 0/? [00:00<?, ?it/s]
In [59]:
preds_without_skips
Out[59]:
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors

# Define class-to-color mapping for trimap
trimap_colors = {
    0: (0, 0, 0),   # Background - Black
    1: (1, 0, 0),   # Outline - Red
    2: (0, 1, 0),   # Object - Green
}

# Create a ListedColormap for visualization
cmap = mcolors.ListedColormap([trimap_colors[i] for i in range(len(trimap_colors))])

# Show the first 10 images, ground truth trimaps, and predicted trimaps
fig, axes = plt.subplots(10, 4, figsize=(12, 30))
axes[0, 0].set_title("Input Image")
axes[0, 1].set_title("Ground Truth Trimap")
axes[0, 2].set_title("Prediction (With Skips)")
axes[0, 3].set_title("Prediction (No Skips)")

for i in range(10):
    # Convert image tensor to NumPy format
    img_np = images[i].permute(1, 2, 0).cpu().numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())  # Normalize for display

    # Convert trimap tensors to NumPy
    trimap_np = trimaps[i]
    pred_skips_np = preds_with_skips[i]
    pred_no_skips_np = preds_without_skips[i]

    axes[i, 0].imshow(img_np)
    axes[i, 0].axis("off")

    axes[i, 1].imshow(trimap_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
    axes[i, 1].axis("off")

    axes[i, 2].imshow(pred_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
    axes[i, 2].axis("off")

    axes[i, 3].imshow(pred_no_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
    axes[i, 3].axis("off")

plt.tight_layout()
plt.show()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[56], line 38
     35 axes[i, 1].imshow(trimap_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
     36 axes[i, 1].axis("off")
---> 38 axes[i, 2].imshow(pred_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
     39 axes[i, 2].axis("off")
     41 axes[i, 3].imshow(pred_no_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)

File ~/.local/lib/python3.10/site-packages/matplotlib/__init__.py:1521, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1518 @functools.wraps(func)
   1519 def inner(ax, *args, data=None, **kwargs):
   1520     if data is None:
-> 1521         return func(
   1522             ax,
   1523             *map(cbook.sanitize_sequence, args),
   1524             **{k: cbook.sanitize_sequence(v) for k, v in kwargs.items()})
   1526     bound = new_sig.bind(ax, *args, **kwargs)
   1527     auto_label = (bound.arguments.get(label_namer)
   1528                   or bound.kwargs.get(label_namer))

File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_axes.py:5945, in Axes.imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
   5942 if aspect is not None:
   5943     self.set_aspect(aspect)
-> 5945 im.set_data(X)
   5946 im.set_alpha(alpha)
   5947 if im.get_clip_path() is None:
   5948     # image does not already have clipping set, clip to Axes patch

File ~/.local/lib/python3.10/site-packages/matplotlib/image.py:675, in _ImageBase.set_data(self, A)
    673 if isinstance(A, PIL.Image.Image):
    674     A = pil_to_array(A)  # Needed e.g. to apply png palette.
--> 675 self._A = self._normalize_image_array(A)
    676 self._imcache = None
    677 self.stale = True

File ~/.local/lib/python3.10/site-packages/matplotlib/image.py:643, in _ImageBase._normalize_image_array(A)
    641     A = A.squeeze(-1)  # If just (M, N, 1), assume scalar and apply colormap.
    642 if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]):
--> 643     raise TypeError(f"Invalid shape {A.shape} for image data")
    644 if A.ndim == 3:
    645     # If the input data has values outside the valid range (after
    646     # normalisation), we issue a warning and then clip X to the bounds
    647     # - otherwise casting wraps extreme values, hiding outliers and
    648     # making reliable interpretation impossible.
    649     high = 255 if np.issubdtype(A.dtype, np.integer) else 1

TypeError: Invalid shape (256,) for image data
No description has been provided for this image